Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Train example comments #756

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions examples/04_training/01_train_dynedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,17 @@ def main(
# Log configuration to W&B
wandb_logger.experiment.config.update(config)

# Define graph representation
# Define graph/data representation, here the KNNGraph is used.
# The KNNGraph is a graph representation, which uses the
# KNNEdges edge definition with 8 neighbours as default.
# The graph representation is defined by the detector,
# in this case the Prometheus detector.
# The standard node definition is used, which is NodesAsPulses.
graph_definition = KNNGraph(detector=Prometheus())

# Use GraphNetDataModule to load in data
# Use GraphNetDataModule to load in data and create dataloaders
# The input here depends on the dataset being used,
# in this case the Prometheus dataset.
dm = GraphNeTDataModule(
dataset_reference=config["dataset_reference"],
dataset_args={
Expand All @@ -110,17 +117,28 @@ def main(

# Building model

# Define architecture of the backbone, in this example
# the DynEdge architecture is used.
# https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003
backbone = DynEdge(
nb_inputs=graph_definition.nb_outputs,
global_pooling_schemes=["min", "max", "mean", "sum"],
)
# Define the task.
# Here an energy reconstruction, with a LogCoshLoss function.
# The target and prediction are transformed using the log10 function.
# When infering the prediction is transformed back to the
# original scale using 10^x.
task = EnergyReconstruction(
hidden_size=backbone.nb_outputs,
target_labels=config["target"],
loss_function=LogCoshLoss(),
transform_prediction_and_target=lambda x: torch.log10(x),
transform_inference=lambda x: torch.pow(10, x),
)
# Define the full model, which includes the backbone, task(s),
# along with typical machine learning options such as
# learning rate optimizers and schedulers.
model = StandardModel(
graph_definition=graph_definition,
backbone=backbone,
Expand Down
Loading