-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_model.jl
48 lines (36 loc) · 1.24 KB
/
run_model.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
using CSV
using DataFrames
using GLM
using ArgParse
using Serialization
"""
    parse_commandline()

Declare the script's command-line options and parse them from `ARGS`.

Recognised options:
- `--input`:  input directory (default `"/input"`)
- `--output`: output directory (default `"/output"`)

Returns the parsed options, looked up by option name (e.g. `args["input"]`).
"""
function parse_commandline()
    settings = ArgParseSettings()
    @add_arg_table settings begin
        "--input"
            help = "Input directory"
            default = "/input"
        "--output"
            help = "Output directory"
            default = "/output"
    end
    parsed = parse_args(ARGS, settings)
    return parsed
end
# Parse command-line arguments (input and output directories).
args = parse_commandline()

# Load the pre-trained model.
# NOTE(review): the path is hard-coded; presumably the model file is shipped at this
# location inside the runtime image — confirm before running elsewhere. Only ever
# deserialize trusted files; Serialization output is also tied to the Julia version.
model = deserialize("/usr/local/bin/model_test_SC1.jls")

# CpG probes used as predictors by the model.
cpg_sites = ["cg18478105", "cg09835024", "cg14361672", "cg01763666",
             "cg12950382", "cg02115394"]

# Load the new data. `permutedims(..., 1, ...)` below treats the first column as the
# probe IDs and every other column as one sample.
input_path = joinpath(args["input"], "Leaderboard_beta_subchallenge1.csv")
raw_df = CSV.read(input_path, DataFrame)

# Keep only the required probes *before* transposing: transposing the full matrix
# would build one column per probe only to discard almost all of them. Row filtering
# keeps column element types, so the transposed result is the same as before.
keep_rows = [probe in cpg_sites for probe in raw_df[!, 1]]
raw_df = raw_df[keep_rows, :]

# Transpose to one row per sample; the former column headers go into "Sample_ID".
new_df = permutedims(raw_df, 1, "Sample_ID")

# Fail with a clear message if any required probe is absent from the input.
absent_sites = setdiff(cpg_sites, names(new_df))
isempty(absent_sites) || throw(ArgumentError(
    "input is missing required CpG sites: $(join(absent_sites, ", "))"))

ID = new_df[:, "Sample_ID"]
new_df = new_df[:, cpg_sites]

# Predict gestational age.
predictions = predict(model, new_df)

# Clamp predictions to [5, 44] (presumably weeks — confirm). An elementwise clamp is
# used instead of boolean-mask assignment because the model may return `missing` for
# a sample with missing predictor values, and a mask containing `missing` cannot be
# used for indexing. `missing` entries are passed through unchanged.
predictions = map(p -> ismissing(p) ? p : clamp(p, 5, 44), predictions)
n_missing = count(ismissing, predictions)
n_missing > 0 && @warn "$n_missing sample(s) have a missing GA prediction"

# Combine predictions with sample IDs.
output_df = DataFrame(ID=ID, GA_prediction=predictions)

# Write the predictions, creating the output directory first if it does not exist.
mkpath(args["output"])
CSV.write(joinpath(args["output"], "predictions.csv"), output_df)