
Commit

Merge pull request #589 from AMHermansen/update-training-examples
Update training examples
AMHermansen authored Sep 11, 2023
2 parents 75e925f + 052e407 commit 4f2f3f3
Showing 2 changed files with 24 additions and 62 deletions.
39 changes: 11 additions & 28 deletions examples/04_training/01_train_model.py
@@ -8,7 +8,7 @@
from pytorch_lightning.utilities import rank_zero_only
from graphnet.constants import EXAMPLE_OUTPUT_DIR
from graphnet.data.dataloader import DataLoader
from graphnet.models import Model
from graphnet.models import StandardModel
from graphnet.training.callbacks import ProgressBar
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.config import (
@@ -27,7 +27,6 @@ def main(
early_stopping_patience: int,
batch_size: int,
num_workers: int,
prediction_names: Optional[List[str]],
suffix: Optional[str] = None,
wandb: bool = False,
) -> None:
@@ -49,7 +48,7 @@ def main(

# Build model
model_config = ModelConfig.load(model_config_path)
model = Model.from_config(model_config, trust=True)
model: StandardModel = StandardModel.from_config(model_config, trust=True)

# Configuration
config = TrainingConfig(
@@ -102,39 +101,31 @@ def main(
**config.fit,
)

# Save model to file
db_name = dataset_config.path.split("/")[-1].split(".")[0]
path = os.path.join(archive, db_name, run_name)
os.makedirs(path, exist_ok=True)
logger.info(f"Writing results to {path}")
model.save_state_dict(f"{path}/state_dict.pth")
model.save(f"{path}/model.pth")

# Get predictions
if isinstance(config.target, str):
prediction_columns = [config.target + "_pred"]
additional_attributes = [config.target]
else:
prediction_columns = [target + "_pred" for target in config.target]
additional_attributes = config.target

if prediction_names:
prediction_columns = prediction_names

logger.info(f"config.target: {config.target}")
logger.info(f"prediction_columns: {prediction_columns}")
logger.info(f"prediction_columns: {model.prediction_labels}")

results = model.predict_as_dataframe(
dataloaders["test"],
prediction_columns=prediction_columns,
additional_attributes=additional_attributes + ["event_no"],
)

# Save predictions and model to file
db_name = dataset_config.path.split("/")[-1].split(".")[0]
path = os.path.join(archive, db_name, run_name)
logger.info(f"Writing results to {path}")
os.makedirs(path, exist_ok=True)

results.to_csv(f"{path}/results.csv")
model.save_state_dict(f"{path}/state_dict.pth")
model.save(f"{path}/model.pth")


if __name__ == "__main__":

# Parse command-line arguments
parser = ArgumentParser(
description="""
@@ -152,13 +143,6 @@ def main(
"num-workers",
)

parser.add_argument(
"--prediction-names",
nargs="+",
help="Names of each prediction output feature (default: %(default)s)",
default=None,
)

parser.add_argument(
"--suffix",
type=str,
@@ -182,7 +166,6 @@ def main(
args.early_stopping_patience,
args.batch_size,
args.num_workers,
args.prediction_names,
args.suffix,
args.wandb,
)
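
Note on the first file: 01_train_model.py now types the example model as a StandardModel, writes the trained model to disk right after fitting, and lets the model report its own output column names via model.prediction_labels, so the --prediction-names flag is removed. A minimal sketch of the updated prediction flow, assuming config files like the ones the example loads; the paths, batch size, and worker count below are placeholders, not part of this diff:

    # Sketch only: illustrates the post-change prediction flow, not a verbatim copy of the example.
    from graphnet.data.dataloader import DataLoader
    from graphnet.models import StandardModel
    from graphnet.utilities.config import DatasetConfig, ModelConfig

    dataset_config = DatasetConfig.load("path/to/dataset_config.yml")  # placeholder path
    model_config = ModelConfig.load("path/to/model_config.yml")  # placeholder path

    # One dataloader per named selection (e.g. "train", "test") defined in the
    # dataset config, matching the example's use of dataloaders["test"].
    dataloaders = DataLoader.from_dataset_config(
        dataset_config,
        batch_size=16,
        num_workers=2,
    )

    model: StandardModel = StandardModel.from_config(model_config, trust=True)
    # ... train with model.fit(...) exactly as in the example ...

    # The model knows its own prediction column names, so nothing has to be
    # assembled from config.target or passed in on the command line.
    results = model.predict_as_dataframe(
        dataloaders["test"],
        prediction_columns=model.prediction_labels,
        additional_attributes=["event_no"],
    )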
47 changes: 13 additions & 34 deletions examples/04_training/03_train_classification_model.py
@@ -14,7 +14,7 @@
)
from graphnet.data.dataloader import DataLoader
from graphnet.data.dataset import Dataset
from graphnet.models import Model
from graphnet.models import StandardModel
from graphnet.training.callbacks import ProgressBar
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.config import (
@@ -33,7 +33,6 @@ def main(
early_stopping_patience: int,
batch_size: int,
num_workers: int,
prediction_names: Optional[List[str]],
suffix: Optional[str] = None,
wandb: bool = False,
) -> None:
@@ -55,7 +54,7 @@ def main(

# Build model
model_config = ModelConfig.load(model_config_path)
model = Model.from_config(model_config, trust=True)
model: StandardModel = StandardModel.from_config(model_config, trust=True)

# Configuration
config = TrainingConfig(
Expand Down Expand Up @@ -107,15 +106,11 @@ def main(
# Log configurations to W&B
# NB: Only log to W&B on the rank-zero process in case of multi-GPU
# training.
if rank_zero_only == 0:
if wandb and rank_zero_only == 0:
wandb_logger.experiment.config.update(config)
wandb_logger.experiment.config.update(model_config.as_dict())
wandb_logger.experiment.config.update(dataset_config.as_dict())

# Build model
model_config = ModelConfig.load(model_config_path)
model = Model.from_config(model_config, trust=True)

# Training model
callbacks = [
EarlyStopping(
@@ -133,36 +128,28 @@ def main(
**config.fit,
)

# Save model to file
db_name = dataset_config.path.split("/")[-1].split(".")[0]
path = os.path.join(archive, db_name, run_name)
os.makedirs(path, exist_ok=True)
logger.info(f"Writing results to {path}")
model.save_state_dict(f"{path}/state_dict.pth")
model.save(f"{path}/model.pth")

# Get predictions
if isinstance(config.target, str):
prediction_columns = [
config.target + "_noise_pred",
config.target + "_muon_pred",
config.target + "_neutrino_pred",
]
additional_attributes = [config.target]
else:
prediction_columns = [target + "_pred" for target in config.target]
additional_attributes = config.target

if prediction_names:
prediction_columns = prediction_names
logger.info(f"config.target: {config.target}")
logger.info(f"prediction_columns: {model.prediction_labels}")

results = model.predict_as_dataframe(
test_dataloaders,
prediction_columns=prediction_columns,
additional_attributes=additional_attributes + ["event_no"],
)

# Save predictions and model to file
db_name = dataset_config.path.split("/")[-1].split(".")[0]
path = os.path.join(archive, db_name, run_name)
logger.info(f"Writing results to {path}")
os.makedirs(path, exist_ok=True)

results.to_csv(f"{path}/results.csv")
model.save_state_dict(f"{path}/state_dict.pth")
model.save(f"{path}/model.pth")


if __name__ == "__main__":
@@ -194,13 +181,6 @@ def main(
"num-workers",
)

parser.add_argument(
"--prediction-names",
nargs="+",
help="Names of each prediction output feature (default: %(default)s)",
default=["noise", "muon", "neutrino"],
)

parser.add_argument(
"--suffix",
type=str,
@@ -218,6 +198,5 @@ def main(
args.early_stopping_patience,
args.batch_size,
args.num_workers,
args.prediction_names,
args.suffix,
)
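
The classification example gets the same treatment, and two smaller clean-ups are visible in its diff: the W&B configuration is now only logged when the --wandb flag is actually set (still restricted to the rank-zero process), and what appears to be a redundant second ModelConfig.load / Model.from_config block is dropped. The guard after the change, paraphrased from the diff (wandb is the boolean CLI flag, wandb_logger the logger built earlier in the script):

    # Log configurations to W&B only when W&B logging was requested, and only
    # from the rank-zero process in case of multi-GPU training.
    if wandb and rank_zero_only == 0:
        wandb_logger.experiment.config.update(config)
        wandb_logger.experiment.config.update(model_config.as_dict())
        wandb_logger.experiment.config.update(dataset_config.as_dict())

As in the first example, the hand-built class columns ("<target>_noise_pred", "_muon_pred", "_neutrino_pred") and the --prediction-names flag are gone; the output columns now come from model.prediction_labels for whatever classification task the model config defines.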
