diff --git a/examples/04_training/01_train_dynedge.py b/examples/04_training/01_train_dynedge.py index 6ee6e0223..8ed4439c3 100644 --- a/examples/04_training/01_train_dynedge.py +++ b/examples/04_training/01_train_dynedge.py @@ -81,10 +81,17 @@ def main( # Log configuration to W&B wandb_logger.experiment.config.update(config) - # Define graph representation + # Define graph/data representation, here the KNNGraph is used. + # The KNNGraph is a graph representation, which uses the + # KNNEdges edge definition with 8 neighbours as default. + # The graph representation is defined by the detector, + # in this case the Prometheus detector. + # The standard node definition is used, which is NodesAsPulses. graph_definition = KNNGraph(detector=Prometheus()) - # Use GraphNetDataModule to load in data + # Use GraphNetDataModule to load in data and create dataloaders + # The input here depends on the dataset being used, + # in this case the Prometheus dataset. dm = GraphNeTDataModule( dataset_reference=config["dataset_reference"], dataset_args={ @@ -110,10 +117,18 @@ def main( # Building model + # Define architecture of the backbone, in this example + # the DynEdge architecture is used. + # https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003 backbone = DynEdge( nb_inputs=graph_definition.nb_outputs, global_pooling_schemes=["min", "max", "mean", "sum"], ) + # Define the task. + # Here an energy reconstruction, with a LogCoshLoss function. + # The target and prediction are transformed using the log10 function. + # When infering the prediction is transformed back to the + # original scale using 10^x. task = EnergyReconstruction( hidden_size=backbone.nb_outputs, target_labels=config["target"], @@ -121,6 +136,9 @@ def main( transform_prediction_and_target=lambda x: torch.log10(x), transform_inference=lambda x: torch.pow(10, x), ) + # Define the full model, which includes the backbone, task(s), + # along with typical machine learning options such as + # learning rate optimizers and schedulers. model = StandardModel( graph_definition=graph_definition, backbone=backbone,