-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #728 from RasmusOrsoe/jammy_flow_integration
Jammy flow integration
- Loading branch information
Showing
13 changed files
with
545 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
"""Example of training a conditional NormalizingFlow.""" | ||
|
||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
from pytorch_lightning.loggers import WandbLogger | ||
import torch | ||
from torch.optim.adam import Adam | ||
|
||
from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR | ||
from graphnet.data.constants import FEATURES, TRUTH | ||
from graphnet.models.detector.prometheus import Prometheus | ||
from graphnet.models.gnn import DynEdge | ||
from graphnet.models.graphs import KNNGraph | ||
from graphnet.training.callbacks import PiecewiseLinearLR | ||
from graphnet.training.utils import make_train_validation_dataloader | ||
from graphnet.utilities.argparse import ArgumentParser | ||
from graphnet.utilities.logging import Logger | ||
from graphnet.utilities.imports import has_jammy_flows_package | ||
|
||
# Make sure the jammy flows is installed | ||
try: | ||
assert has_jammy_flows_package() | ||
from graphnet.models import NormalizingFlow | ||
except AssertionError: | ||
raise AssertionError( | ||
"This example requires the package`jammy_flow` " | ||
" to be installed. It appears that the package is " | ||
" not installed. Please install the package." | ||
) | ||
|
||
# Constants | ||
features = FEATURES.PROMETHEUS | ||
truth = TRUTH.PROMETHEUS | ||
|
||
|
||
def main( | ||
path: str, | ||
pulsemap: str, | ||
target: str, | ||
truth_table: str, | ||
gpus: Optional[List[int]], | ||
max_epochs: int, | ||
early_stopping_patience: int, | ||
batch_size: int, | ||
num_workers: int, | ||
wandb: bool = False, | ||
) -> None: | ||
"""Run example.""" | ||
# Construct Logger | ||
logger = Logger() | ||
|
||
# Initialise Weights & Biases (W&B) run | ||
if wandb: | ||
# Make sure W&B output directory exists | ||
wandb_dir = "./wandb/" | ||
os.makedirs(wandb_dir, exist_ok=True) | ||
wandb_logger = WandbLogger( | ||
project="example-script", | ||
entity="graphnet-team", | ||
save_dir=wandb_dir, | ||
log_model=True, | ||
) | ||
|
||
logger.info(f"features: {features}") | ||
logger.info(f"truth: {truth}") | ||
|
||
# Configuration | ||
config: Dict[str, Any] = { | ||
"path": path, | ||
"pulsemap": pulsemap, | ||
"batch_size": batch_size, | ||
"num_workers": num_workers, | ||
"target": target, | ||
"early_stopping_patience": early_stopping_patience, | ||
"fit": { | ||
"gpus": gpus, | ||
"max_epochs": max_epochs, | ||
}, | ||
} | ||
|
||
archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs") | ||
run_name = "dynedge_{}_example".format(config["target"]) | ||
if wandb: | ||
# Log configuration to W&B | ||
wandb_logger.experiment.config.update(config) | ||
|
||
# Define graph representation | ||
graph_definition = KNNGraph(detector=Prometheus()) | ||
|
||
( | ||
training_dataloader, | ||
validation_dataloader, | ||
) = make_train_validation_dataloader( | ||
db=config["path"], | ||
graph_definition=graph_definition, | ||
pulsemaps=config["pulsemap"], | ||
features=features, | ||
truth=truth, | ||
batch_size=config["batch_size"], | ||
num_workers=config["num_workers"], | ||
truth_table=truth_table, | ||
selection=None, | ||
) | ||
|
||
# Building model | ||
|
||
backbone = DynEdge( | ||
nb_inputs=graph_definition.nb_outputs, | ||
global_pooling_schemes=["min", "max", "mean", "sum"], | ||
) | ||
|
||
model = NormalizingFlow( | ||
graph_definition=graph_definition, | ||
backbone=backbone, | ||
optimizer_class=Adam, | ||
target_labels=config["target"], | ||
optimizer_kwargs={"lr": 1e-03, "eps": 1e-03}, | ||
scheduler_class=PiecewiseLinearLR, | ||
scheduler_kwargs={ | ||
"milestones": [ | ||
0, | ||
len(training_dataloader) / 2, | ||
len(training_dataloader) * config["fit"]["max_epochs"], | ||
], | ||
"factors": [1e-2, 1, 1e-02], | ||
}, | ||
scheduler_config={ | ||
"interval": "step", | ||
}, | ||
) | ||
|
||
# Training model | ||
model.fit( | ||
training_dataloader, | ||
validation_dataloader, | ||
early_stopping_patience=config["early_stopping_patience"], | ||
logger=wandb_logger if wandb else None, | ||
**config["fit"], | ||
) | ||
|
||
# Get predictions | ||
additional_attributes = model.target_labels | ||
assert isinstance(additional_attributes, list) # mypy | ||
|
||
results = model.predict_as_dataframe( | ||
validation_dataloader, | ||
additional_attributes=additional_attributes + ["event_no"], | ||
gpus=config["fit"]["gpus"], | ||
) | ||
|
||
# Save predictions and model to file | ||
db_name = path.split("/")[-1].split(".")[0] | ||
path = os.path.join(archive, db_name, run_name) | ||
logger.info(f"Writing results to {path}") | ||
os.makedirs(path, exist_ok=True) | ||
|
||
# Save results as .csv | ||
results.to_csv(f"{path}/results.csv") | ||
|
||
# Save full model (including weights) to .pth file - not version safe | ||
# Note: Models saved as .pth files in one version of graphnet | ||
# may not be compatible with a different version of graphnet. | ||
model.save(f"{path}/model.pth") | ||
|
||
# Save model config and state dict - Version safe save method. | ||
# This method of saving models is the safest way. | ||
model.save_state_dict(f"{path}/state_dict.pth") | ||
model.save_config(f"{path}/model_config.yml") | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
# Parse command-line arguments | ||
parser = ArgumentParser( | ||
description=""" | ||
Train conditional NormalizingFlow without the use of config files. | ||
""" | ||
) | ||
|
||
parser.add_argument( | ||
"--path", | ||
help="Path to dataset file (default: %(default)s)", | ||
default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db", | ||
) | ||
|
||
parser.add_argument( | ||
"--pulsemap", | ||
help="Name of pulsemap to use (default: %(default)s)", | ||
default="total", | ||
) | ||
|
||
parser.add_argument( | ||
"--target", | ||
help=( | ||
"Name of feature to use as regression target (default: " | ||
"%(default)s)" | ||
), | ||
default="total_energy", | ||
) | ||
|
||
parser.add_argument( | ||
"--truth-table", | ||
help="Name of truth table to be used (default: %(default)s)", | ||
default="mc_truth", | ||
) | ||
|
||
parser.with_standard_arguments( | ||
"gpus", | ||
("max-epochs", 1), | ||
"early-stopping-patience", | ||
("batch-size", 50), | ||
"num-workers", | ||
) | ||
|
||
parser.add_argument( | ||
"--wandb", | ||
action="store_true", | ||
help="If True, Weights & Biases are used to track the experiment.", | ||
) | ||
|
||
args, unknown = parser.parse_known_args() | ||
|
||
main( | ||
args.path, | ||
args.pulsemap, | ||
args.target, | ||
args.truth_table, | ||
args.gpus, | ||
args.max_epochs, | ||
args.early_stopping_patience, | ||
args.batch_size, | ||
args.num_workers, | ||
args.wandb, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.