Skip to content

Commit

Permalink
Merge branch 'graphnet-team:main' into fix_i3tray
Browse files Browse the repository at this point in the history
  • Loading branch information
pweigel authored Nov 22, 2024
2 parents d19fe03 + 89cef52 commit 63044f6
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions examples/04_training/01_train_dynedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,17 @@ def main(
# Log configuration to W&B
wandb_logger.experiment.config.update(config)

# Define graph representation
# Define graph/data representation, here the KNNGraph is used.
# The KNNGraph is a graph representation, which uses the
# KNNEdges edge definition with 8 neighbours as default.
# The graph representation is defined by the detector,
# in this case the Prometheus detector.
# The standard node definition is used, which is NodesAsPulses.
graph_definition = KNNGraph(detector=Prometheus())

# Use GraphNetDataModule to load in data
# Use GraphNetDataModule to load in data and create dataloaders
# The input here depends on the dataset being used,
# in this case the Prometheus dataset.
dm = GraphNeTDataModule(
dataset_reference=config["dataset_reference"],
dataset_args={
Expand All @@ -110,17 +117,28 @@ def main(

# Building model

# Define architecture of the backbone, in this example
# the DynEdge architecture is used.
# https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003
backbone = DynEdge(
nb_inputs=graph_definition.nb_outputs,
global_pooling_schemes=["min", "max", "mean", "sum"],
)
# Define the task.
# Here an energy reconstruction, with a LogCoshLoss function.
# The target and prediction are transformed using the log10 function.
# When infering the prediction is transformed back to the
# original scale using 10^x.
task = EnergyReconstruction(
hidden_size=backbone.nb_outputs,
target_labels=config["target"],
loss_function=LogCoshLoss(),
transform_prediction_and_target=lambda x: torch.log10(x),
transform_inference=lambda x: torch.pow(10, x),
)
# Define the full model, which includes the backbone, task(s),
# along with typical machine learning options such as
# learning rate optimizers and schedulers.
model = StandardModel(
graph_definition=graph_definition,
backbone=backbone,
Expand Down

0 comments on commit 63044f6

Please sign in to comment.