From bed86f21a0b5b183ab3a6f4103cbc58b84fce6b4 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 23 Jan 2024 14:46:32 +0900 Subject: [PATCH] add example script --- examples/04_training/05_train_RNN_TITO.py | 266 ++++++++++++++++++++++ src/graphnet/models/graphs/nodes/nodes.py | 29 ++- 2 files changed, 288 insertions(+), 7 deletions(-) create mode 100644 examples/04_training/05_train_RNN_TITO.py diff --git a/examples/04_training/05_train_RNN_TITO.py b/examples/04_training/05_train_RNN_TITO.py new file mode 100644 index 000000000..d3832bafb --- /dev/null +++ b/examples/04_training/05_train_RNN_TITO.py @@ -0,0 +1,266 @@ +"""Example of training RNN-TITO model. + +with time-series data. +""" + +import os +from typing import Any, Dict, List, Optional + +from pytorch_lightning.loggers import WandbLogger +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR +from graphnet.data.constants import FEATURES, TRUTH +from graphnet.models import StandardModel +from graphnet.models.detector.prometheus import Prometheus +from graphnet.models.gnn import RNN_TITO +from graphnet.models.graphs import KNNGraph +from graphnet.models.graphs.nodes import NodeAsDOMTimeSeries +from graphnet.models.task.reconstruction import ( + DirectionReconstructionWithKappa, +) +from graphnet.training.labels import Direction +from graphnet.training.loss_functions import VonMisesFisher3DLoss +from graphnet.training.utils import make_train_validation_dataloader +from graphnet.utilities.argparse import ArgumentParser +from graphnet.utilities.logging import Logger + +# Constants +features = FEATURES.PROMETHEUS +truth = TRUTH.PROMETHEUS + + +def main( + path: str, + pulsemap: str, + target: str, + truth_table: str, + gpus: Optional[List[int]], + max_epochs: int, + early_stopping_patience: int, + batch_size: int, + num_workers: int, + wandb: bool = False, +) -> None: + """Run example.""" + # Construct Logger + logger = Logger() + + # Initialise Weights & Biases (W&B) run + if wandb: + # Make sure W&B output directory exists + wandb_dir = "./wandb/" + os.makedirs(wandb_dir, exist_ok=True) + wandb_logger = WandbLogger( + project="example-script", + entity="graphnet-team", + save_dir=wandb_dir, + log_model=True, + ) + + logger.info(f"features: {features}") + logger.info(f"truth: {truth}") + + # Configuration + config: Dict[str, Any] = { + "path": path, + "pulsemap": pulsemap, + "batch_size": batch_size, + "num_workers": num_workers, + "target": target, + "early_stopping_patience": early_stopping_patience, + "fit": { + "gpus": gpus, + "max_epochs": max_epochs, + }, + } + + graph_definition = KNNGraph( + detector=Prometheus(), + node_definition=NodeAsDOMTimeSeries( + keys=features, + id_columns=features[0:3], + time_column=features[-1], + charge_column="None", + ), + ) + archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_RNN_TITO_model") + run_name = "RNN_TITO_{}_example".format(config["target"]) + if wandb: + # Log configuration to W&B + wandb_logger.experiment.config.update(config) + + ( + training_dataloader, + validation_dataloader, + ) = make_train_validation_dataloader( + db=config["path"], + graph_definition=graph_definition, + selection=None, + pulsemaps=config["pulsemap"], + features=features, + truth=truth, + batch_size=config["batch_size"], + num_workers=config["num_workers"], + truth_table=truth_table, + index_column="event_no", + labels={ + "direction": Direction( + azimuth_key="injection_azimuth", zenith_key="injection_zenith" + ) + }, + ) + + # Building model + backbone = RNN_TITO( + nb_inputs=graph_definition.nb_outputs, + nb_neighbours=8, + RNN_layers=2, + RNN_hidden_size=64, + RNN_dropout=0.5, + features_subset=[0, 1, 2, 3], + dyntrans_layer_sizes=[(256, 256), (256, 256), (256, 256), (256, 256)], + post_processing_layer_sizes=[336, 256], + readout_layer_sizes=[256, 128], + global_pooling_schemes=["max"], + embedding_dim=0, + n_head=16, + use_global_features=True, + use_post_processing_layers=True, + ) + + task = DirectionReconstructionWithKappa( + hidden_size=backbone.nb_outputs, + target_labels=config["target"], + loss_function=VonMisesFisher3DLoss(), + ) + model = StandardModel( + graph_definition=graph_definition, + backbone=backbone, + tasks=[task], + optimizer_class=Adam, + optimizer_kwargs={"lr": 1e-03, "eps": 1e-03}, + scheduler_class=ReduceLROnPlateau, + scheduler_kwargs={ + "patience": config["early_stopping_patience"], + }, + scheduler_config={ + "frequency": 1, + "monitor": "val_loss", + }, + ) + + # Training model + + model.fit( + training_dataloader, + validation_dataloader, + early_stopping_patience=config["early_stopping_patience"], + logger=wandb_logger if wandb else None, + **config["fit"], + ) + + # Get predictions + additional_attributes = [ + "injection_zenith", + "injection_azimuth", + "event_no", + ] + prediction_columns = [ + config["target"][0] + "_x_pred", + config["target"][0] + "_y_pred", + config["target"][0] + "_z_pred", + config["target"][0] + "_kappa_pred", + ] + + assert isinstance(additional_attributes, list) # mypy + + results = model.predict_as_dataframe( + validation_dataloader, + additional_attributes=additional_attributes, + prediction_columns=prediction_columns, + gpus=config["fit"]["gpus"], + ) + + # Save predictions and model to file + db_name = path.split("/")[-1].split(".")[0] + path = os.path.join(archive, db_name, run_name) + logger.info(f"Writing results to {path}") + os.makedirs(path, exist_ok=True) + + # Save results as .csv + results.to_csv(f"{path}/results.csv") + + # Save full model (including weights) to .pth file - Not version proof + model.save(f"{path}/model.pth") + + # Save model config and state dict - Version safe save method. + model.save_state_dict(f"{path}/state_dict.pth") + model.save_config(f"{path}/model_config.yml") + + +if __name__ == "__main__": + + # Parse command-line arguments + parser = ArgumentParser( + description=""" +Train GNN model without the use of config files. +""" + ) + + parser.add_argument( + "--path", + help="Path to dataset file (default: %(default)s)", + default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db", + ) + + parser.add_argument( + "--pulsemap", + help="Name of pulsemap to use (default: %(default)s)", + default="total", + ) + + parser.add_argument( + "--target", + help=( + "Name of feature to use as regression target (default: " + "%(default)s)" + ), + default="direction", + ) + + parser.add_argument( + "--truth-table", + help="Name of truth table to be used (default: %(default)s)", + default="mc_truth", + ) + + parser.with_standard_arguments( + "gpus", + ("max-epochs", 1), + ("early-stopping-patience", 2), + ("batch-size", 16), + "num-workers", + ) + + parser.add_argument( + "--wandb", + action="store_true", + help="If True, Weights & Biases are used to track the experiment.", + ) + + args, unknown = parser.parse_known_args() + + main( + args.path, + args.pulsemap, + args.target, + args.truth_table, + args.gpus, + args.max_epochs, + args.early_stopping_patience, + args.batch_size, + args.num_workers, + args.wandb, + ) diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index 4056f93cf..f81213a68 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -1,6 +1,6 @@ """Class(es) for building/connecting graphs.""" -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Union from abc import abstractmethod import torch @@ -216,7 +216,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: return Data(x=torch.tensor(array)) -class NodeAsDOMTimeSeries: +class NodeAsDOMTimeSeries(NodeDefinition): """Represent each node as a DOM with time and charge time series data.""" def __init__( @@ -243,11 +243,21 @@ def __init__( max_activations: Maximum number of activations to include in the time series. """ self._keys = keys + super().__init__(input_feature_names=self._keys) self._id_columns = [self._keys.index(key) for key in id_columns] self._time_index = self._keys.index(time_column) - self._charge_index = self._keys.index(charge_column) + try: + self._charge_index: Optional[int] = self._keys.index(charge_column) + except ValueError: + self.warning( + "Charge column with name {} not found. Running without.".format( + charge_column + ) + ) + + self._charge_index = None + self._max_activations = max_activations - super().__init__() def _define_output_feature_names( self, input_feature_names: List[str] @@ -258,10 +268,15 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: """Construct nodes from raw node features ´x´.""" # Cast to Numpy x = x.numpy() + # if there is no charge column add a dummy column of zeros with the same shape as the time column + if self._charge_index is None: + charge_index: int = len(self._keys) + x = np.insert(x, charge_index, np.zeros(x.shape[0]), axis=1) + # Sort by time x = x[x[:, self._time_index].argsort()] # Undo log10 scaling so we can sum charges - x[:, self._charge_index] = np.power(10, x[:, self._charge_index]) + x[:, charge_index] = np.power(10, x[:, charge_index]) # Shift time to start at 0 x[:, self._time_index] -= np.min(x[:, self._time_index]) # Group pulses on the same DOM @@ -279,10 +294,10 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: counts = sort_this[:, unique_sensors.shape[1] :].flatten().astype(int) time_series = np.split( - x[:, [self._charge_index, self._time_index]], counts.cumsum()[:-1] + x[:, [charge_index, self._time_index]], counts.cumsum()[:-1] ) - # add total charge to unique dom features and apply inverse hyperbolic sine scaling + # add first time and total charge to unique dom features and apply inverse hyperbolic sine scaling time_charge = np.stack( [ (image[0, 1], np.arcsinh(5 * image[:, 0].sum()) / 5)