forked from graphnet-team/graphnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
106 lines (87 loc) · 2.95 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Simplified example of training Model."""
import os
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from graphnet.data.dataloader import DataLoader
from graphnet.models import Model
from graphnet.training.callbacks import ProgressBar
from graphnet.utilities.config import (
DatasetConfig,
ModelConfig,
TrainingConfig,
)
# Make sure W&B output directory exists
# NOTE: this runs at import time (module level), so simply importing this
# script creates the directory and starts a W&B run.
WANDB_DIR = "./wandb/"
os.makedirs(WANDB_DIR, exist_ok=True)
# Initialise Weights & Biases (W&B) run
# `log_model=True` uploads model checkpoints as W&B artifacts.
# NOTE(review): `entity="graphnet-team"` requires membership of that W&B
# team — adjust when running outside the graphnet organisation.
wandb_logger = WandbLogger(
project="example-script",
entity="graphnet-team",
save_dir=WANDB_DIR,
log_model=True,
)
def main() -> None:
    """Train a GNN `Model` from config files and save its predictions.

    Loads dataset and model configurations from YAML, trains on the
    "train"/"validation" splits, predicts on the "test" split, and writes
    the predictions (CSV), state dict, and full model to the archive
    directory. Assumes the module-level ``wandb_logger`` is initialised.
    """
    # Configuration
    config = TrainingConfig(
        target="energy",
        early_stopping_patience=5,
        fit={"gpus": [0, 1], "max_epochs": 5},
        dataloader={"batch_size": 128, "num_workers": 10},
    )
    archive = "/groups/icecube/asogaard/gnn/results/"
    run_name = "dynedge_{}_example".format(config.target)

    # Construct dataloaders for the train/validation/test selections
    # declared in the dataset config.
    dataset_config = DatasetConfig.load(
        "configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml"
    )
    dataloaders = DataLoader.from_dataset_config(
        dataset_config,
        **config.dataloader,
    )

    # Build model from its YAML config. `trust=True` allows the config to
    # instantiate arbitrary classes — only safe for trusted config files.
    model_config = ModelConfig.load(f"configs/models/{run_name}.yml")
    model = Model.from_config(model_config, trust=True)

    # Log configurations to W&B
    # NB: Only log to W&B on the rank-zero process in case of multi-GPU
    # training.
    # FIX: the original compared the imported `rank_zero_only` function
    # object to 0 (`rank_zero_only == 0`), which is always False, so the
    # configs were never logged. The current process rank is exposed as
    # the `rank` attribute on `rank_zero_only`.
    if rank_zero_only.rank == 0:
        wandb_logger.experiment.config.update(config)
        wandb_logger.experiment.config.update(model_config.as_dict())
        wandb_logger.experiment.config.update(dataset_config.as_dict())

    # Train model with early stopping on validation loss.
    callbacks = [
        EarlyStopping(
            monitor="val_loss",
            patience=config.early_stopping_patience,
        ),
        ProgressBar(),
    ]
    model.fit(
        dataloaders["train"],
        dataloaders["validation"],
        callbacks=callbacks,
        logger=wandb_logger,
        **config.fit,
    )

    # Get predictions. `config.target` may be a single target name or a
    # list of names; normalise both cases to lists of column names.
    if isinstance(config.target, str):
        prediction_columns = [config.target + "_pred"]
        additional_attributes = [config.target]
    else:
        prediction_columns = [target + "_pred" for target in config.target]
        additional_attributes = config.target
    results = model.predict_as_dataframe(
        dataloaders["test"],
        prediction_columns=prediction_columns,
        additional_attributes=additional_attributes + ["event_no"],
    )

    # Save predictions and model to file. Derive the database name from
    # the dataset path (filename without extension).
    db_name = dataset_config.path.split("/")[-1].split(".")[0]
    path = os.path.join(archive, db_name, run_name)
    # FIX: ensure the output directory exists before writing; otherwise
    # `to_csv`/`save` raise FileNotFoundError on a fresh archive.
    os.makedirs(path, exist_ok=True)
    results.to_csv(f"{path}/results.csv")
    model.save_state_dict(f"{path}/state_dict.pth")
    model.save(f"{path}/model.pth")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()