From 64f5b9be67cf67eaabeb4283a018b7df48f242a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= <andreas.sogaard@gmail.com>
Date: Thu, 14 Sep 2023 09:02:04 +0000
Subject: [PATCH] =?UTF-8?q?Deploying=20to=20gh-pages=20from=20@=20graphnet?=
 =?UTF-8?q?-team/graphnet@6edf835e0f6c1e044d853c5d8b74a5c5ea50f99d=20?=
 =?UTF-8?q?=F0=9F=9A=80?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 _modules/graphnet/data/constants.html         |    2 +-
 _modules/graphnet/data/dataconverter.html     |    2 +-
 _modules/graphnet/data/dataloader.html        |  457 ++++++
 _modules/graphnet/data/dataset/dataset.html   | 1084 ++++++++++++++
 .../data/dataset/parquet/parquet_dataset.html |  500 +++++++
 .../data/dataset/sqlite/sqlite_dataset.html   |  515 +++++++
 .../sqlite/sqlite_dataset_perturbed.html      |  515 +++++++
 .../graphnet/data/extractors/i3extractor.html |    2 +-
 .../data/extractors/i3featureextractor.html   |    2 +-
 .../data/extractors/i3genericextractor.html   |    2 +-
 .../extractors/i3hybridrecoextractor.html     |    2 +-
 .../extractors/i3ntmuonlabelsextractor.html   |    2 +-
 .../data/extractors/i3particleextractor.html  |    2 +-
 .../data/extractors/i3pisaextractor.html      |    2 +-
 .../data/extractors/i3quesoextractor.html     |    2 +-
 .../data/extractors/i3retroextractor.html     |    2 +-
 .../data/extractors/i3splinempeextractor.html |    2 +-
 .../data/extractors/i3truthextractor.html     |    2 +-
 .../data/extractors/i3tumextractor.html       |    2 +-
 .../extractors/utilities/collections.html     |    2 +-
 .../data/extractors/utilities/frames.html     |    2 +-
 .../data/extractors/utilities/types.html      |    2 +-
 .../data/parquet/parquet_dataconverter.html   |    2 +-
 _modules/graphnet/data/pipeline.html          |  593 ++++++++
 .../data/sqlite/sqlite_dataconverter.html     |    2 +-
 .../data/sqlite/sqlite_utilities.html         |    2 +-
 .../data/utilities/parquet_to_sqlite.html     |    2 +-
 _modules/graphnet/data/utilities/random.html  |    2 +-
 .../utilities/string_selection_resolver.html  |    2 +-
 .../deployment/i3modules/graphnet_module.html |  817 +++++++++++
 _modules/graphnet/models/coarsening.html      |  711 +++++++++
 .../graphnet/models/components/layers.html    |  579 ++++++++
 _modules/graphnet/models/components/pool.html |  656 +++++++++
 .../graphnet/models/detector/detector.html    |  421 ++++++
 .../graphnet/models/detector/icecube.html     |  528 +++++++
 .../graphnet/models/detector/prometheus.html  |  395 +++++
 _modules/graphnet/models/gnn/convnet.html     |  486 +++++++
 _modules/graphnet/models/gnn/dynedge.html     |  693 +++++++++
 .../graphnet/models/gnn/dynedge_jinst.html    |  521 +++++++
 .../models/gnn/dynedge_kaggle_tito.html       |  618 ++++++++
 _modules/graphnet/models/gnn/gnn.html         |  403 +++++
 .../graphnet/models/graphs/edges/edges.html   |  563 +++++++
 .../models/graphs/graph_definition.html       |  634 ++++++++
 _modules/graphnet/models/graphs/graphs.html   |  410 ++++++
 .../graphnet/models/graphs/nodes/nodes.html   |  444 ++++++
 _modules/graphnet/models/model.html           |  726 ++++++++++
 _modules/graphnet/models/standard_model.html  |  604 ++++++++
 .../graphnet/models/task/classification.html  |  411 ++++++
 .../graphnet/models/task/reconstruction.html  |  609 ++++++++
 _modules/graphnet/models/task/task.html       |  688 +++++++++
 _modules/graphnet/models/utils.html           |  430 ++++++
 _modules/graphnet/pisa/fitting.html           |    2 +-
 _modules/graphnet/pisa/plotting.html          |    2 +-
 _modules/graphnet/training/callbacks.html     |  544 +++++++
 _modules/graphnet/training/labels.html        |  436 ++++++
 .../graphnet/training/loss_functions.html     |  859 +++++++++++
 _modules/graphnet/training/utils.html         |  656 +++++++++
 .../graphnet/training/weight_fitting.html     |    2 +-
 _modules/graphnet/utilities/argparse.html     |    2 +-
 .../utilities/config/base_config.html         |  449 ++++++
 .../utilities/config/configurable.html        |  408 ++++++
 .../utilities/config/dataset_config.html      |  585 ++++++++
 .../utilities/config/model_config.html        |  654 +++++++++
 .../graphnet/utilities/config/parsing.html    |  475 ++++++
 .../utilities/config/training_config.html     |  378 +++++
 _modules/graphnet/utilities/filesys.html      |    2 +-
 _modules/graphnet/utilities/imports.html      |    2 +-
 _modules/graphnet/utilities/logging.html      |    2 +-
 _modules/graphnet/utilities/maths.html        |  371 +++++
 _modules/index.html                           |   41 +-
 about.html                                    |    2 +-
 api/graphnet.constants.html                   |    2 +-
 api/graphnet.data.constants.html              |    2 +-
 api/graphnet.data.dataconverter.html          |    2 +-
 api/graphnet.data.dataloader.html             |  129 +-
 api/graphnet.data.dataset.dataset.html        |  330 ++++-
 api/graphnet.data.dataset.html                |   16 +-
 api/graphnet.data.dataset.parquet.html        |   12 +-
 ....data.dataset.parquet.parquet_dataset.html |  116 +-
 api/graphnet.data.dataset.sqlite.html         |   17 +-
 ...et.data.dataset.sqlite.sqlite_dataset.html |  116 +-
 ...taset.sqlite.sqlite_dataset_perturbed.html |   84 +-
 api/graphnet.data.extractors.html             |    2 +-
 api/graphnet.data.extractors.i3extractor.html |    2 +-
 ...et.data.extractors.i3featureextractor.html |    2 +-
 ...et.data.extractors.i3genericextractor.html |    2 +-
 ...data.extractors.i3hybridrecoextractor.html |    2 +-
 ...ta.extractors.i3ntmuonlabelsextractor.html |    2 +-
 ...t.data.extractors.i3particleextractor.html |    2 +-
 ...phnet.data.extractors.i3pisaextractor.html |    2 +-
 ...hnet.data.extractors.i3quesoextractor.html |    2 +-
 ...hnet.data.extractors.i3retroextractor.html |    2 +-
 ....data.extractors.i3splinempeextractor.html |    2 +-
 ...hnet.data.extractors.i3truthextractor.html |    2 +-
 ...aphnet.data.extractors.i3tumextractor.html |    2 +-
 ...data.extractors.utilities.collections.html |    2 +-
 ...hnet.data.extractors.utilities.frames.html |    2 +-
 api/graphnet.data.extractors.utilities.html   |    2 +-
 ...phnet.data.extractors.utilities.types.html |    2 +-
 api/graphnet.data.html                        |   14 +-
 api/graphnet.data.parquet.html                |    2 +-
 ...et.data.parquet.parquet_dataconverter.html |    2 +-
 api/graphnet.data.pipeline.html               |   55 +-
 api/graphnet.data.sqlite.html                 |    2 +-
 ...hnet.data.sqlite.sqlite_dataconverter.html |    2 +-
 ...graphnet.data.sqlite.sqlite_utilities.html |    2 +-
 api/graphnet.data.utilities.html              |    2 +-
 ...hnet.data.utilities.parquet_to_sqlite.html |    2 +-
 api/graphnet.data.utilities.random.html       |    2 +-
 ...a.utilities.string_selection_resolver.html |    4 +-
 api/graphnet.deployment.html                  |    2 +-
 ...raphnet.deployment.i3modules.deployer.html |    2 +-
 ....deployment.i3modules.graphnet_module.html |  135 +-
 api/graphnet.deployment.i3modules.html        |    9 +-
 api/graphnet.html                             |    2 +-
 api/graphnet.models.coarsening.html           |  226 ++-
 api/graphnet.models.components.html           |   28 +-
 api/graphnet.models.components.layers.html    |  253 +++-
 api/graphnet.models.components.pool.html      |  339 ++++-
 api/graphnet.models.detector.detector.html    |   89 +-
 api/graphnet.models.detector.html             |   25 +-
 api/graphnet.models.detector.icecube.html     |  197 ++-
 api/graphnet.models.detector.prometheus.html  |   62 +-
 api/graphnet.models.gnn.convnet.html          |   76 +-
 api/graphnet.models.gnn.dynedge.html          |   98 +-
 api/graphnet.models.gnn.dynedge_jinst.html    |   73 +-
 ...aphnet.models.gnn.dynedge_kaggle_tito.html |   83 +-
 api/graphnet.models.gnn.gnn.html              |  103 +-
 api/graphnet.models.gnn.html                  |   32 +-
 api/graphnet.models.graphs.edges.edges.html   |  169 ++-
 api/graphnet.models.graphs.edges.html         |   18 +-
 ...aphnet.models.graphs.graph_definition.html |   93 +-
 api/graphnet.models.graphs.graphs.html        |   50 +-
 api/graphnet.models.graphs.html               |   20 +-
 api/graphnet.models.graphs.nodes.html         |   16 +-
 api/graphnet.models.graphs.nodes.nodes.html   |  132 +-
 api/graphnet.models.html                      |   39 +-
 api/graphnet.models.model.html                |  305 +++-
 api/graphnet.models.standard_model.html       |  348 ++++-
 api/graphnet.models.task.classification.html  |  262 +++-
 api/graphnet.models.task.html                 |   36 +-
 api/graphnet.models.task.reconstruction.html  | 1290 ++++++++++++++++-
 api/graphnet.models.task.task.html            |  302 +++-
 api/graphnet.models.utils.html                |  110 +-
 api/graphnet.pisa.fitting.html                |    2 +-
 api/graphnet.pisa.html                        |    2 +-
 api/graphnet.pisa.plotting.html               |    2 +-
 api/graphnet.training.callbacks.html          |  277 +++-
 api/graphnet.training.html                    |   38 +-
 api/graphnet.training.labels.html             |   93 +-
 api/graphnet.training.loss_functions.html     |  488 ++++++-
 api/graphnet.training.utils.html              |  191 ++-
 api/graphnet.training.weight_fitting.html     |    2 +-
 api/graphnet.utilities.argparse.html          |    2 +-
 ...graphnet.utilities.config.base_config.html |  177 ++-
 ...raphnet.utilities.config.configurable.html |  105 +-
 ...phnet.utilities.config.dataset_config.html |  438 +++++-
 api/graphnet.utilities.config.html            |   45 +-
 ...raphnet.utilities.config.model_config.html |  177 ++-
 api/graphnet.utilities.config.parsing.html    |  165 ++-
 ...hnet.utilities.config.training_config.html |  147 +-
 api/graphnet.utilities.decorators.html        |    2 +-
 api/graphnet.utilities.filesys.html           |    2 +-
 api/graphnet.utilities.html                   |    7 +-
 api/graphnet.utilities.imports.html           |    2 +-
 api/graphnet.utilities.logging.html           |    2 +-
 api/graphnet.utilities.maths.html             |   41 +-
 api/modules.html                              |    2 +-
 contribute.html                               |    2 +-
 genindex.html                                 | 1174 ++++++++++++++-
 index.html                                    |    2 +-
 install.html                                  |    2 +-
 objects.inv                                   |  Bin 3520 -> 6210 bytes
 py-modindex.html                              |  257 +++-
 search.html                                   |    2 +-
 searchindex.js                                |    2 +-
 sitemap.xml                                   |    2 +-
 177 files changed, 31425 insertions(+), 329 deletions(-)
 create mode 100644 _modules/graphnet/data/dataloader.html
 create mode 100644 _modules/graphnet/data/dataset/dataset.html
 create mode 100644 _modules/graphnet/data/dataset/parquet/parquet_dataset.html
 create mode 100644 _modules/graphnet/data/dataset/sqlite/sqlite_dataset.html
 create mode 100644 _modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html
 create mode 100644 _modules/graphnet/data/pipeline.html
 create mode 100644 _modules/graphnet/deployment/i3modules/graphnet_module.html
 create mode 100644 _modules/graphnet/models/coarsening.html
 create mode 100644 _modules/graphnet/models/components/layers.html
 create mode 100644 _modules/graphnet/models/components/pool.html
 create mode 100644 _modules/graphnet/models/detector/detector.html
 create mode 100644 _modules/graphnet/models/detector/icecube.html
 create mode 100644 _modules/graphnet/models/detector/prometheus.html
 create mode 100644 _modules/graphnet/models/gnn/convnet.html
 create mode 100644 _modules/graphnet/models/gnn/dynedge.html
 create mode 100644 _modules/graphnet/models/gnn/dynedge_jinst.html
 create mode 100644 _modules/graphnet/models/gnn/dynedge_kaggle_tito.html
 create mode 100644 _modules/graphnet/models/gnn/gnn.html
 create mode 100644 _modules/graphnet/models/graphs/edges/edges.html
 create mode 100644 _modules/graphnet/models/graphs/graph_definition.html
 create mode 100644 _modules/graphnet/models/graphs/graphs.html
 create mode 100644 _modules/graphnet/models/graphs/nodes/nodes.html
 create mode 100644 _modules/graphnet/models/model.html
 create mode 100644 _modules/graphnet/models/standard_model.html
 create mode 100644 _modules/graphnet/models/task/classification.html
 create mode 100644 _modules/graphnet/models/task/reconstruction.html
 create mode 100644 _modules/graphnet/models/task/task.html
 create mode 100644 _modules/graphnet/models/utils.html
 create mode 100644 _modules/graphnet/training/callbacks.html
 create mode 100644 _modules/graphnet/training/labels.html
 create mode 100644 _modules/graphnet/training/loss_functions.html
 create mode 100644 _modules/graphnet/training/utils.html
 create mode 100644 _modules/graphnet/utilities/config/base_config.html
 create mode 100644 _modules/graphnet/utilities/config/configurable.html
 create mode 100644 _modules/graphnet/utilities/config/dataset_config.html
 create mode 100644 _modules/graphnet/utilities/config/model_config.html
 create mode 100644 _modules/graphnet/utilities/config/parsing.html
 create mode 100644 _modules/graphnet/utilities/config/training_config.html
 create mode 100644 _modules/graphnet/utilities/maths.html

diff --git a/_modules/graphnet/data/constants.html b/_modules/graphnet/data/constants.html
index cd439b37c..b6526af39 100644
--- a/_modules/graphnet/data/constants.html
+++ b/_modules/graphnet/data/constants.html
@@ -436,7 +436,7 @@ <h1 id="modules-graphnet-data-constants--page-root">Source code for graphnet.dat
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/dataconverter.html b/_modules/graphnet/data/dataconverter.html
index aa78de84b..9be52a7d9 100644
--- a/_modules/graphnet/data/dataconverter.html
+++ b/_modules/graphnet/data/dataconverter.html
@@ -938,7 +938,7 @@ <h1 id="modules-graphnet-data-dataconverter--page-root">Source code for graphnet
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/dataloader.html b/_modules/graphnet/data/dataloader.html
new file mode 100644
index 000000000..4de320872
--- /dev/null
+++ b/_modules/graphnet/data/dataloader.html
@@ -0,0 +1,457 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.data.dataloader &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/data/dataloader" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.data.dataloader </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-data-dataloader--page-root">Source code for graphnet.data.dataloader</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Base `Dataloader` class(es) used in `graphnet`."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span>
+
+<span class="kn">import</span> <span class="nn">torch.utils.data</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Batch</span><span class="p">,</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.data.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">DatasetConfig</span>
+
+
+<div class="viewcode-block" id="collate_fn">
+<a class="viewcode-back" href="../../../api/graphnet.data.dataloader.html#graphnet.data.dataloader.collate_fn">[docs]</a>
+<span class="k">def</span> <span class="nf">collate_fn</span><span class="p">(</span><span class="n">graphs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Data</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Batch</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Remove graphs with less than two DOM hits.</span>
+
+<span class="sd">    Should not occur in "production.</span>
+<span class="sd">    """</span>
+    <span class="n">graphs</span> <span class="o">=</span> <span class="p">[</span><span class="n">g</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">graphs</span> <span class="k">if</span> <span class="n">g</span><span class="o">.</span><span class="n">n_pulses</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">]</span>
+    <span class="k">return</span> <span class="n">Batch</span><span class="o">.</span><span class="n">from_data_list</span><span class="p">(</span><span class="n">graphs</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="do_shuffle">
+<a class="viewcode-back" href="../../../api/graphnet.data.dataloader.html#graphnet.data.dataloader.do_shuffle">[docs]</a>
+<span class="k">def</span> <span class="nf">do_shuffle</span><span class="p">(</span><span class="n">selection_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Check whether to shuffle selection with name `selection_name`."""</span>
+    <span class="k">return</span> <span class="s2">"train"</span> <span class="ow">in</span> <span class="n">selection_name</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span></div>
+
+
+
+<div class="viewcode-block" id="DataLoader">
+<a class="viewcode-back" href="../../../api/graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader">[docs]</a>
+<span class="k">class</span> <span class="nc">DataLoader</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Class for loading data from a `Dataset`."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">dataset</span><span class="p">:</span> <span class="n">Dataset</span><span class="p">,</span>
+        <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
+        <span class="n">shuffle</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
+        <span class="n">persistent_workers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+        <span class="n">collate_fn</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="n">collate_fn</span><span class="p">,</span>
+        <span class="n">prefetch_factor</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `DataLoader`."""</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">dataset</span><span class="p">,</span>
+            <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
+            <span class="n">shuffle</span><span class="o">=</span><span class="n">shuffle</span><span class="p">,</span>
+            <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
+            <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">,</span>
+            <span class="n">persistent_workers</span><span class="o">=</span><span class="n">persistent_workers</span><span class="p">,</span>
+            <span class="n">prefetch_factor</span><span class="o">=</span><span class="n">prefetch_factor</span><span class="p">,</span>
+            <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
+        <span class="p">)</span>
+
+<div class="viewcode-block" id="DataLoader.from_dataset_config">
+<a class="viewcode-back" href="../../../api/graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader.from_dataset_config">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">from_dataset_config</span><span class="p">(</span>
+        <span class="bp">cls</span><span class="p">,</span>
+        <span class="n">config</span><span class="p">:</span> <span class="n">DatasetConfig</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="s2">"DataLoader"</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"DataLoader"</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Construct `DataLoader`s based on selections in `DatasetConfig`."""</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
+            <span class="k">assert</span> <span class="s2">"shuffle"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">,</span> <span class="p">(</span>
+                <span class="s2">"When passing a `DatasetConfig` with multiple selections, "</span>
+                <span class="s2">"`shuffle` is automatically inferred from the selection name, "</span>
+                <span class="s2">"and thus should not specified as an argument."</span>
+            <span class="p">)</span>
+            <span class="n">datasets</span> <span class="o">=</span> <span class="n">Dataset</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">datasets</span><span class="p">,</span> <span class="nb">dict</span><span class="p">)</span>
+            <span class="n">data_loaders</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">DataLoader</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
+            <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">dataset</span> <span class="ow">in</span> <span class="n">datasets</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
+                <span class="n">data_loaders</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span>
+                    <span class="n">dataset</span><span class="p">,</span>
+                    <span class="n">shuffle</span><span class="o">=</span><span class="n">do_shuffle</span><span class="p">(</span><span class="n">name</span><span class="p">),</span>
+                    <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
+                <span class="p">)</span>
+
+            <span class="k">return</span> <span class="n">data_loaders</span>
+
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="s2">"shuffle"</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">,</span> <span class="p">(</span>
+                <span class="s2">"When passing a `DatasetConfig` with a single selections, you "</span>
+                <span class="s2">"need to specify `shuffle` as an argument."</span>
+            <span class="p">)</span>
+            <span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">Dataset</span><span class="p">)</span>
+            <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/data/dataset/dataset.html b/_modules/graphnet/data/dataset/dataset.html
new file mode 100644
index 000000000..f2ee86f4d
--- /dev/null
+++ b/_modules/graphnet/data/dataset/dataset.html
@@ -0,0 +1,1084 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.data.dataset.dataset &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/data/dataset/dataset" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.data.dataset.dataset </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-data-dataset-dataset--page-root">Source code for graphnet.data.dataset.dataset</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Base :py:class:`Dataset` class(es) used in GraphNeT."""</span>
+
+<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">cast</span><span class="p">,</span>
+    <span class="n">Any</span><span class="p">,</span>
+    <span class="n">Callable</span><span class="p">,</span>
+    <span class="n">Dict</span><span class="p">,</span>
+    <span class="n">List</span><span class="p">,</span>
+    <span class="n">Optional</span><span class="p">,</span>
+    <span class="n">Tuple</span><span class="p">,</span>
+    <span class="n">Union</span><span class="p">,</span>
+    <span class="n">Iterable</span><span class="p">,</span>
+    <span class="n">Type</span><span class="p">,</span>
+<span class="p">)</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.constants</span> <span class="kn">import</span> <span class="n">GRAPHNET_ROOT_DIR</span>
+<span class="kn">from</span> <span class="nn">graphnet.data.utilities.string_selection_resolver</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">StringSelectionResolver</span><span class="p">,</span>
+<span class="p">)</span>
+<span class="kn">from</span> <span class="nn">graphnet.training.labels</span> <span class="kn">import</span> <span class="n">Label</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">Configurable</span><span class="p">,</span>
+    <span class="n">DatasetConfig</span><span class="p">,</span>
+    <span class="n">save_dataset_config</span><span class="p">,</span>
+<span class="p">)</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config.parsing</span> <span class="kn">import</span> <span class="n">traverse_and_apply</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config.parsing</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">get_all_grapnet_classes</span><span class="p">,</span>
+<span class="p">)</span>
+
+
+<div class="viewcode-block" id="ColumnMissingException">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.ColumnMissingException">[docs]</a>
+<span class="k">class</span> <span class="nc">ColumnMissingException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Exception to indicate a missing column in a dataset."""</span></div>
+
+
+
+<div class="viewcode-block" id="load_module">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.load_module">[docs]</a>
+<span class="k">def</span> <span class="nf">load_module</span><span class="p">(</span><span class="n">class_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Type</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Load graphnet module from string name.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        class_name: name of class</span>
+
+<span class="sd">    Returns:</span>
+<span class="sd">        graphnet module.</span>
+<span class="sd">    """</span>
+    <span class="c1"># Get a lookup for all classes in `graphnet`</span>
+    <span class="kn">import</span> <span class="nn">graphnet.data</span>
+    <span class="kn">import</span> <span class="nn">graphnet.models</span>
+    <span class="kn">import</span> <span class="nn">graphnet.training</span>
+
+    <span class="n">namespace_classes</span> <span class="o">=</span> <span class="n">get_all_grapnet_classes</span><span class="p">(</span>
+        <span class="n">graphnet</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">graphnet</span><span class="o">.</span><span class="n">models</span><span class="p">,</span> <span class="n">graphnet</span><span class="o">.</span><span class="n">training</span>
+    <span class="p">)</span>
+    <span class="k">return</span> <span class="n">namespace_classes</span><span class="p">[</span><span class="n">class_name</span><span class="p">]</span></div>
+
+
+
+<div class="viewcode-block" id="parse_graph_definition">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.parse_graph_definition">[docs]</a>
+<span class="k">def</span> <span class="nf">parse_graph_definition</span><span class="p">(</span><span class="n">cfg</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">GraphDefinition</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Construct GraphDefinition from DatasetConfig."""</span>
+    <span class="k">assert</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+
+    <span class="n">args</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">][</span><span class="s2">"arguments"</span><span class="p">]</span>
+    <span class="n">classes</span> <span class="o">=</span> <span class="p">{}</span>
+    <span class="k">for</span> <span class="n">arg</span> <span class="ow">in</span> <span class="n">args</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">],</span> <span class="nb">dict</span><span class="p">):</span>
+            <span class="k">if</span> <span class="s2">"class_name"</span> <span class="ow">in</span> <span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">]</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+                <span class="n">classes</span><span class="p">[</span><span class="n">arg</span><span class="p">]</span> <span class="o">=</span> <span class="n">load_module</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">][</span><span class="s2">"class_name"</span><span class="p">])(</span>
+                    <span class="o">**</span><span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">][</span><span class="s2">"arguments"</span><span class="p">]</span>
+                <span class="p">)</span>
+        <span class="k">if</span> <span class="n">arg</span> <span class="o">==</span> <span class="s2">"dtype"</span><span class="p">:</span>
+            <span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">]</span> <span class="o">=</span> <span class="nb">eval</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">])</span>  <span class="c1"># converts string to class</span>
+
+    <span class="n">new_cfg</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">args</span><span class="p">)</span>
+    <span class="n">new_cfg</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">classes</span><span class="p">)</span>
+    <span class="n">graph_definition</span> <span class="o">=</span> <span class="n">load_module</span><span class="p">(</span><span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">][</span><span class="s2">"class_name"</span><span class="p">])(</span>
+        <span class="o">**</span><span class="n">new_cfg</span>
+    <span class="p">)</span>
+    <span class="k">return</span> <span class="n">graph_definition</span></div>
+
+
+
+<div class="viewcode-block" id="Dataset">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset">[docs]</a>
+<span class="k">class</span> <span class="nc">Dataset</span><span class="p">(</span><span class="n">Logger</span><span class="p">,</span> <span class="n">Configurable</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">,</span> <span class="n">ABC</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base Dataset class for reading from any intermediate file format."""</span>
+
+    <span class="c1"># Class method(s)</span>
+<div class="viewcode-block" id="Dataset.from_config">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.from_config">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span>  <span class="c1"># type: ignore[override]</span>
+        <span class="bp">cls</span><span class="p">,</span>
+        <span class="n">source</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DatasetConfig</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span>
+        <span class="s2">"Dataset"</span><span class="p">,</span>
+        <span class="s2">"EnsembleDataset"</span><span class="p">,</span>
+        <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"Dataset"</span><span class="p">],</span>
+        <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"EnsembleDataset"</span><span class="p">],</span>
+    <span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Construct `Dataset` instance from `source` configuration."""</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="n">source</span> <span class="o">=</span> <span class="n">DatasetConfig</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">source</span><span class="p">)</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="p">,</span> <span class="n">DatasetConfig</span><span class="p">),</span> <span class="p">(</span>
+            <span class="sa">f</span><span class="s2">"Argument `source` of type (</span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">source</span><span class="p">)</span><span class="si">}</span><span class="s2">) is not a "</span>
+            <span class="s2">"`DatasetConfig`"</span>
+        <span class="p">)</span>
+
+        <span class="k">assert</span> <span class="p">(</span>
+            <span class="s2">"graph_definition"</span> <span class="ow">in</span> <span class="n">source</span><span class="o">.</span><span class="n">dict</span><span class="p">()</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
+        <span class="p">),</span> <span class="s2">"`DatasetConfig` incompatible with current GraphNeT version."</span>
+
+        <span class="c1"># Parse set of `selection``.</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
+            <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_construct_datasets_from_dict</span><span class="p">(</span><span class="n">source</span><span class="p">)</span>
+        <span class="k">elif</span> <span class="p">(</span>
+            <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+            <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">source</span><span class="o">.</span><span class="n">selection</span><span class="p">)</span>
+            <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="o">.</span><span class="n">selection</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">str</span><span class="p">)</span>
+        <span class="p">):</span>
+            <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_construct_dataset_from_list_of_strings</span><span class="p">(</span><span class="n">source</span><span class="p">)</span>
+
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">source</span><span class="o">.</span><span class="n">dict</span><span class="p">()</span>
+        <span class="k">if</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">]</span> <span class="o">=</span> <span class="n">parse_graph_definition</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">source</span><span class="o">.</span><span class="n">_dataset_class</span><span class="p">(</span><span class="o">**</span><span class="n">cfg</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="Dataset.concatenate">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.concatenate">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">concatenate</span><span class="p">(</span>
+        <span class="bp">cls</span><span class="p">,</span>
+        <span class="n">datasets</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">"Dataset"</span><span class="p">],</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"EnsembleDataset"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Concatenate multiple `Dataset`s into one instance."""</span>
+        <span class="k">return</span> <span class="n">EnsembleDataset</span><span class="p">(</span><span class="n">datasets</span><span class="p">)</span></div>
+
+
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">_construct_datasets_from_dict</span><span class="p">(</span>
+        <span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">DatasetConfig</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"Dataset"</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Construct `Dataset` for each entry in dict `self.selection`."""</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">dict</span><span class="p">)</span>
+        <span class="n">datasets</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"Dataset"</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
+        <span class="n">selections</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">]]</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">)</span>
+        <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">selection</span> <span class="ow">in</span> <span class="n">selections</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
+            <span class="n">config</span><span class="o">.</span><span class="n">selection</span> <span class="o">=</span> <span class="n">selection</span>
+            <span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="p">(</span><span class="n">Dataset</span><span class="p">,</span> <span class="n">EnsembleDataset</span><span class="p">))</span>
+            <span class="n">datasets</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span>
+
+        <span class="c1"># Reset `selections`.</span>
+        <span class="n">config</span><span class="o">.</span><span class="n">selection</span> <span class="o">=</span> <span class="n">selections</span>
+
+        <span class="k">return</span> <span class="n">datasets</span>
+
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">_construct_dataset_from_list_of_strings</span><span class="p">(</span>
+        <span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">DatasetConfig</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"Dataset"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `Dataset` for each entry in list `self.selection`."""</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+        <span class="n">datasets</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">"Dataset"</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="n">selections</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">))</span>
+        <span class="k">for</span> <span class="n">selection</span> <span class="ow">in</span> <span class="n">selections</span><span class="p">:</span>
+            <span class="n">config</span><span class="o">.</span><span class="n">selection</span> <span class="o">=</span> <span class="n">selection</span>
+            <span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">Dataset</span><span class="p">)</span>
+            <span class="n">datasets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
+
+        <span class="c1"># Reset `selections`.</span>
+        <span class="n">config</span><span class="o">.</span><span class="n">selection</span> <span class="o">=</span> <span class="n">selections</span>
+
+        <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">datasets</span><span class="p">)</span>
+
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">_resolve_graphnet_paths</span><span class="p">(</span>
+        <span class="bp">cls</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]:</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="k">return</span> <span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_resolve_graphnet_paths</span><span class="p">(</span><span class="n">p</span><span class="p">))</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">path</span><span class="p">]</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
+        <span class="k">return</span> <span class="p">(</span>
+            <span class="n">path</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"$graphnet"</span><span class="p">,</span> <span class="n">GRAPHNET_ROOT_DIR</span><span class="p">)</span>
+            <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"$GRAPHNET"</span><span class="p">,</span> <span class="n">GRAPHNET_ROOT_DIR</span><span class="p">)</span>
+            <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"$</span><span class="si">{graphnet}</span><span class="s2">"</span><span class="p">,</span> <span class="n">GRAPHNET_ROOT_DIR</span><span class="p">)</span>
+            <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"$</span><span class="si">{GRAPHNET}</span><span class="s2">"</span><span class="p">,</span> <span class="n">GRAPHNET_ROOT_DIR</span><span class="p">)</span>
+        <span class="p">)</span>
+
+    <span class="nd">@save_dataset_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
+        <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span>
+        <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
+        <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span><span class="p">,</span>
+        <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span><span class="p">,</span>
+        <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">string_selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
+        <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct Dataset.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            path: Path to the file(s) from which this `Dataset` should read.</span>
+<span class="sd">            pulsemaps: Name(s) of the pulse map series that should be used to</span>
+<span class="sd">                construct the nodes on the individual graph objects, and their</span>
+<span class="sd">                features. Multiple pulse series maps can be used, e.g., when</span>
+<span class="sd">                different DOM types are stored in different maps.</span>
+<span class="sd">            features: List of columns in the input files that should be used as</span>
+<span class="sd">                node features on the graph objects.</span>
+<span class="sd">            truth: List of event-level columns in the input files that should</span>
+<span class="sd">                be used added as attributes on the  graph objects.</span>
+<span class="sd">            node_truth: List of node-level columns in the input files that</span>
+<span class="sd">                should be used added as attributes on the graph objects.</span>
+<span class="sd">            index_column: Name of the column in the input files that contains</span>
+<span class="sd">                unique indicies to identify and map events across tables.</span>
+<span class="sd">            truth_table: Name of the table containing event-level truth</span>
+<span class="sd">                information.</span>
+<span class="sd">            node_truth_table: Name of the table containing node-level truth</span>
+<span class="sd">                information.</span>
+<span class="sd">            string_selection: Subset of strings for which data should be read</span>
+<span class="sd">                and used to construct graph objects. Defaults to None, meaning</span>
+<span class="sd">                all strings for which data exists are used.</span>
+<span class="sd">            selection: The events that should be read. This can be given either</span>
+<span class="sd">                as list of indicies (in `index_column`); or a string-based</span>
+<span class="sd">                selection used to query the `Dataset` for events passing the</span>
+<span class="sd">                selection. Defaults to None, meaning that all events in the</span>
+<span class="sd">                input files are read.</span>
+<span class="sd">            dtype: Type of the feature tensor on the graph objects returned.</span>
+<span class="sd">            loss_weight_table: Name of the table containing per-event loss</span>
+<span class="sd">                weights.</span>
+<span class="sd">            loss_weight_column: Name of the column in `loss_weight_table`</span>
+<span class="sd">                containing per-event loss weights. This is also the name of the</span>
+<span class="sd">                corresponding attribute assigned to the graph object.</span>
+<span class="sd">            loss_weight_default_value: Default per-event loss weight.</span>
+<span class="sd">                NOTE: This default value is only applied when</span>
+<span class="sd">                `loss_weight_table` and `loss_weight_column` are specified, and</span>
+<span class="sd">                in this case to events with no value in the corresponding</span>
+<span class="sd">                table/column. That is, if no per-event loss weight table/column</span>
+<span class="sd">                is provided, this value is ignored. Defaults to None.</span>
+<span class="sd">            seed: Random number generator seed, used for selecting a random</span>
+<span class="sd">                subset of events when resolving a string-based selection (e.g.,</span>
+<span class="sd">                `"10000 random events ~ event_no % 5 &gt; 0"` or `"20% random</span>
+<span class="sd">                events ~ event_no % 5 &gt; 0"`).</span>
+<span class="sd">            graph_definition: Method that defines the graph representation.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pulsemaps</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="n">pulsemaps</span> <span class="o">=</span> <span class="p">[</span><span class="n">pulsemaps</span><span class="p">]</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">truth</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span>
+
+        <span class="c1"># Resolve reference to `$GRAPHNET` in path(s)</span>
+        <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_resolve_graphnet_paths</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
+
+        <span class="c1"># Member variable(s)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_path</span> <span class="o">=</span> <span class="n">path</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemaps</span> <span class="o">=</span> <span class="n">pulsemaps</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_features</span> <span class="o">=</span> <span class="p">[</span><span class="n">index_column</span><span class="p">]</span> <span class="o">+</span> <span class="n">features</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span> <span class="o">=</span> <span class="p">[</span><span class="n">index_column</span><span class="p">]</span> <span class="o">+</span> <span class="n">truth</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span> <span class="o">=</span> <span class="n">index_column</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span> <span class="o">=</span> <span class="n">truth_table</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_default_value</span> <span class="o">=</span> <span class="n">loss_weight_default_value</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span> <span class="o">=</span> <span class="n">graph_definition</span>
+
+        <span class="k">if</span> <span class="n">node_truth</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">node_truth_table</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">node_truth</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+                <span class="n">node_truth</span> <span class="o">=</span> <span class="p">[</span><span class="n">node_truth</span><span class="p">]</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span> <span class="o">=</span> <span class="n">node_truth</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_table</span> <span class="o">=</span> <span class="n">node_truth_table</span>
+
+        <span class="k">if</span> <span class="n">string_selection</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                <span class="p">(</span>
+                    <span class="s2">"String selection detected.</span><span class="se">\n</span><span class="s2"> "</span>
+                    <span class="sa">f</span><span class="s2">"Accepted strings: </span><span class="si">{</span><span class="n">string_selection</span><span class="si">}</span><span class="se">\n</span><span class="s2"> "</span>
+                    <span class="s2">"All other strings are ignored!"</span>
+                <span class="p">)</span>
+            <span class="p">)</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">string_selection</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
+                <span class="n">string_selection</span> <span class="o">=</span> <span class="p">[</span><span class="n">string_selection</span><span class="p">]</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection</span> <span class="o">=</span> <span class="n">string_selection</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"string in </span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_string_selection</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span> <span class="o">=</span> <span class="n">loss_weight_column</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span> <span class="o">=</span> <span class="n">loss_weight_table</span>
+        <span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+        <span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Error: no loss weight table specified"</span><span class="p">)</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
+        <span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span> <span class="ow">is</span> <span class="kc">None</span>
+        <span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Error: no loss weight column specified"</span><span class="p">)</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="n">dtype</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_label_fns</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Data</span><span class="p">],</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="p">{}</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection_resolver</span> <span class="o">=</span> <span class="n">StringSelectionResolver</span><span class="p">(</span>
+            <span class="bp">self</span><span class="p">,</span>
+            <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span>
+            <span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Implementation-specific initialisation.</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_init</span><span class="p">()</span>
+
+        <span class="c1"># Set unique indices</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span>
+        <span class="k">if</span> <span class="n">selection</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_indices</span><span class="p">()</span>
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">selection</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_resolve_string_selection_to_indices</span><span class="p">(</span>
+                <span class="n">selection</span>
+            <span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span> <span class="o">=</span> <span class="n">selection</span>
+
+        <span class="c1"># Purely internal member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="p">{}</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_remove_missing_columns</span><span class="p">()</span>
+
+        <span class="c1"># Implementation-specific post-init code.</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_post_init</span><span class="p">()</span>
+
+    <span class="c1"># Properties</span>
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Path to the file(s) from which this `Dataset` reads."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">truth_table</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Name of the table containing event-level truth information."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span>
+
+    <span class="c1"># Abstract method(s)</span>
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Set internal representation needed to read data from input file."""</span>
+
+    <span class="k">def</span> <span class="nf">_post_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Implementation-specific code executed after the main constructor."""</span>
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_get_all_indices</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return a list of all available values in `self._index_column`."""</span>
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_get_event_index</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return the event index corresponding to a `sequential_index`."""</span>
+
+<div class="viewcode-block" id="Dataset.query_table">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.query_table">[docs]</a>
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">query_table</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">table</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">columns</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">],</span>
+        <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Query a table at a specific index, optionally with some selection.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            table: Table to be queried.</span>
+<span class="sd">            columns: Columns to read out.</span>
+<span class="sd">            sequential_index: Sequentially numbered index</span>
+<span class="sd">                (i.e. in [0,len(self))) of the event to query. This _may_</span>
+<span class="sd">                differ from the indexation used in `self._indices`. If no value</span>
+<span class="sd">                is provided, the entire column is returned.</span>
+<span class="sd">            selection: Selection to be imposed before reading out data.</span>
+<span class="sd">                Defaults to None.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            List of tuples containing the values in `columns`. If the `table`</span>
+<span class="sd">                contains only scalar data for `columns`, a list of length 1 is</span>
+<span class="sd">                returned</span>
+
+<span class="sd">        Raises:</span>
+<span class="sd">            ColumnMissingException: If one or more element in `columns` is not</span>
+<span class="sd">                present in `table`.</span>
+<span class="sd">        """</span></div>
+
+
+    <span class="c1"># Public method(s)</span>
+<div class="viewcode-block" id="Dataset.add_label">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.add_label">[docs]</a>
+    <span class="k">def</span> <span class="nf">add_label</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Data</span><span class="p">],</span> <span class="n">Any</span><span class="p">],</span> <span class="n">key</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Add custom graph label define using function `fn`."""</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">Label</span><span class="p">):</span>
+            <span class="n">key</span> <span class="o">=</span> <span class="n">fn</span><span class="o">.</span><span class="n">key</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span>
+            <span class="n">key</span><span class="p">,</span> <span class="nb">str</span>
+        <span class="p">),</span> <span class="s2">"Please specify a key for the custom label to be added."</span>
+        <span class="k">assert</span> <span class="p">(</span>
+            <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_label_fns</span>
+        <span class="p">),</span> <span class="sa">f</span><span class="s2">"A custom label </span><span class="si">{</span><span class="n">key</span><span class="si">}</span><span class="s2"> has already been defined."</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_label_fns</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">fn</span></div>
+
+
+    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return number of graphs in `Dataset`."""</span>
+        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return graph `Data` object at `index`."""</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">sequential_index</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)):</span>
+            <span class="k">raise</span> <span class="ne">IndexError</span><span class="p">(</span>
+                <span class="sa">f</span><span class="s2">"Index </span><span class="si">{</span><span class="n">sequential_index</span><span class="si">}</span><span class="s2"> not in range [0, </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">]"</span>
+            <span class="p">)</span>
+        <span class="n">features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_query</span><span class="p">(</span>
+            <span class="n">sequential_index</span>
+        <span class="p">)</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_graph</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">graph</span>
+
+    <span class="c1"># Internal method(s)</span>
+    <span class="k">def</span> <span class="nf">_resolve_string_selection_to_indices</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">selection</span><span class="p">:</span> <span class="nb">str</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Resolve selection as string to list of indices.</span>
+
+<span class="sd">        Selections are expected to have pandas.DataFrame.query-compatible</span>
+<span class="sd">        syntax, e.g., ``` "event_no % 5 &gt; 0" ``` Selections may also specify a</span>
+<span class="sd">        fixed number of events to randomly sample, e.g., ``` "10000 random</span>
+<span class="sd">        events ~ event_no % 5 &gt; 0" "20% random events ~ event_no % 5 &gt; 0" ```</span>
+<span class="sd">        """</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection_resolver</span><span class="o">.</span><span class="n">resolve</span><span class="p">(</span><span class="n">selection</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_remove_missing_columns</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Remove columns that are not present in the input file.</span>
+
+<span class="sd">        Columns are removed from `self._features` and `self._truth`.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Check if table is completely empty</span>
+        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Dataset is empty."</span><span class="p">)</span>
+            <span class="k">return</span>
+
+        <span class="c1"># Find missing features</span>
+        <span class="n">missing_features_set</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">)</span>
+        <span class="k">for</span> <span class="n">pulsemap</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemaps</span><span class="p">:</span>
+            <span class="n">missing</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_missing_columns</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">,</span> <span class="n">pulsemap</span><span class="p">)</span>
+            <span class="n">missing_features_set</span> <span class="o">=</span> <span class="n">missing_features_set</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="n">missing</span><span class="p">)</span>
+
+        <span class="n">missing_features</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">missing_features_set</span><span class="p">)</span>
+
+        <span class="c1"># Find missing truth variables</span>
+        <span class="n">missing_truth_variables</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_missing_columns</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Remove missing features</span>
+        <span class="k">if</span> <span class="n">missing_features</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                <span class="s2">"Removing the following (missing) features: "</span>
+                <span class="o">+</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">missing_features</span><span class="p">)</span>
+            <span class="p">)</span>
+            <span class="k">for</span> <span class="n">missing_feature</span> <span class="ow">in</span> <span class="n">missing_features</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">missing_feature</span><span class="p">)</span>
+
+        <span class="c1"># Remove missing truth variables</span>
+        <span class="k">if</span> <span class="n">missing_truth_variables</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                <span class="p">(</span>
+                    <span class="s2">"Removing the following (missing) truth variables: "</span>
+                    <span class="o">+</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">missing_truth_variables</span><span class="p">)</span>
+                <span class="p">)</span>
+            <span class="p">)</span>
+            <span class="k">for</span> <span class="n">missing_truth_variable</span> <span class="ow">in</span> <span class="n">missing_truth_variables</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">missing_truth_variable</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_check_missing_columns</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">table</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return a list missing columns in `table`."""</span>
+        <span class="k">for</span> <span class="n">column</span> <span class="ow">in</span> <span class="n">columns</span><span class="p">:</span>
+            <span class="k">try</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span><span class="n">table</span><span class="p">,</span> <span class="p">[</span><span class="n">column</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
+            <span class="k">except</span> <span class="n">ColumnMissingException</span><span class="p">:</span>
+                <span class="k">if</span> <span class="n">table</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="p">:</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="p">[</span><span class="n">table</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="p">[</span><span class="n">table</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">column</span><span class="p">)</span>
+            <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Dataset contains no entries for </span><span class="si">{</span><span class="n">column</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">table</span><span class="p">,</span> <span class="p">[])</span>
+
+    <span class="k">def</span> <span class="nf">_query</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="nb">int</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span>
+        <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="o">...</span><span class="p">]],</span>
+        <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
+        <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]],</span>
+        <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span>
+    <span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Query file for event features and truth information.</span>
+
+<span class="sd">        The returned lists have lengths corresponding to the number of pulses</span>
+<span class="sd">        in the event. Their constituent tuples have lengths corresponding to</span>
+<span class="sd">        the number of features/attributes in each output</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            sequential_index: Sequentially numbered index</span>
+<span class="sd">                (i.e. in [0,len(self))) of the event to query. This _may_</span>
+<span class="sd">                differ from the indexation used in `self._indices`.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            Tuple containing pulse-level event features; event-level truth</span>
+<span class="sd">                information; pulse-level truth information; and event-level</span>
+<span class="sd">                loss weights, respectively.</span>
+<span class="sd">        """</span>
+        <span class="n">features</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="k">for</span> <span class="n">pulsemap</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemaps</span><span class="p">:</span>
+            <span class="n">features_pulsemap</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span>
+                <span class="n">pulsemap</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span>
+            <span class="p">)</span>
+            <span class="n">features</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">features_pulsemap</span><span class="p">)</span>
+
+        <span class="n">truth</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">,</span> <span class="n">sequential_index</span>
+        <span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+            <span class="n">node_truth</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_table</span><span class="p">,</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">,</span>
+                <span class="n">sequential_index</span><span class="p">,</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span><span class="p">,</span>
+            <span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">node_truth</span> <span class="o">=</span> <span class="kc">None</span>
+
+        <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>  <span class="c1"># Default</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+            <span class="n">loss_weight_list</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span><span class="p">,</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span><span class="p">,</span>
+                <span class="n">sequential_index</span><span class="p">,</span>
+            <span class="p">)</span>
+            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">loss_weight_list</span><span class="p">):</span>
+                <span class="n">loss_weight</span> <span class="o">=</span> <span class="n">loss_weight_list</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="n">loss_weight</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.0</span>
+
+        <span class="k">return</span> <span class="n">features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span>
+
+    <span class="k">def</span> <span class="nf">_create_graph</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="o">...</span><span class="p">]],</span>
+        <span class="n">truth</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
+        <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Create Pytorch Data (i.e. graph) object.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            features: List of tuples, containing event features.</span>
+<span class="sd">            truth: List of tuples, containing truth information.</span>
+<span class="sd">            node_truth: List of tuples, containing node-level truth.</span>
+<span class="sd">            loss_weight: A weight associated with the event for weighing the</span>
+<span class="sd">                loss.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            Graph object.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Convert nested list to simple dict</span>
+        <span class="n">truth_dict</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="n">key</span><span class="p">:</span> <span class="n">truth</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">)</span>
+        <span class="p">}</span>
+
+        <span class="c1"># Define custom labels</span>
+        <span class="n">labels_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_labels</span><span class="p">(</span><span class="n">truth_dict</span><span class="p">)</span>
+
+        <span class="c1"># Convert nested list to simple dict</span>
+        <span class="k">if</span> <span class="n">node_truth</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">node_truth_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">node_truth</span><span class="p">)</span>
+            <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+            <span class="n">node_truth_dict</span> <span class="o">=</span> <span class="p">{</span>
+                <span class="n">key</span><span class="p">:</span> <span class="n">node_truth_array</span><span class="p">[:,</span> <span class="n">index</span><span class="p">]</span>
+                <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">)</span>
+            <span class="p">}</span>
+
+        <span class="c1"># Create list of truth dicts with labels</span>
+        <span class="n">truth_dicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">labels_dict</span><span class="p">,</span> <span class="n">truth_dict</span><span class="p">]</span>
+        <span class="k">if</span> <span class="n">node_truth</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">truth_dicts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">node_truth_dict</span><span class="p">)</span>
+
+        <span class="c1"># Catch cases with no reconstructed pulses</span>
+        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">features</span><span class="p">):</span>
+            <span class="n">node_features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">features</span><span class="p">)[</span>
+                <span class="p">:,</span> <span class="mi">1</span><span class="p">:</span>
+            <span class="p">]</span>  <span class="c1"># first entry is index column</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">node_features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([])</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
+
+        <span class="c1"># Construct graph data object</span>
+        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span><span class="p">(</span>
+            <span class="n">node_features</span><span class="o">=</span><span class="n">node_features</span><span class="p">,</span>
+            <span class="n">node_feature_names</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">[</span>
+                <span class="mi">1</span><span class="p">:</span>
+            <span class="p">],</span>  <span class="c1"># first entry is index column</span>
+            <span class="n">truth_dicts</span><span class="o">=</span><span class="n">truth_dicts</span><span class="p">,</span>
+            <span class="n">custom_label_functions</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_label_fns</span><span class="p">,</span>
+            <span class="n">loss_weight_column</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span><span class="p">,</span>
+            <span class="n">loss_weight</span><span class="o">=</span><span class="n">loss_weight</span><span class="p">,</span>
+            <span class="n">loss_weight_default_value</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_default_value</span><span class="p">,</span>
+            <span class="n">data_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="n">graph</span>
+
+    <span class="k">def</span> <span class="nf">_get_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">truth_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return dictionary of  labels, to be added as graph attributes."""</span>
+        <span class="k">if</span> <span class="s2">"pid"</span> <span class="ow">in</span> <span class="n">truth_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+            <span class="n">abs_pid</span> <span class="o">=</span> <span class="nb">abs</span><span class="p">(</span><span class="n">truth_dict</span><span class="p">[</span><span class="s2">"pid"</span><span class="p">])</span>
+            <span class="n">sim_type</span> <span class="o">=</span> <span class="n">truth_dict</span><span class="p">[</span><span class="s2">"sim_type"</span><span class="p">]</span>
+
+            <span class="n">labels_dict</span> <span class="o">=</span> <span class="p">{</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">:</span> <span class="n">truth_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">],</span>
+                <span class="s2">"muon"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">13</span><span class="p">),</span>
+                <span class="s2">"muon_stopped"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">truth_dict</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"stopped_muon"</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">),</span>
+                <span class="s2">"noise"</span><span class="p">:</span> <span class="nb">int</span><span class="p">((</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">sim_type</span> <span class="o">!=</span> <span class="s2">"data"</span><span class="p">)),</span>
+                <span class="s2">"neutrino"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span>
+                    <span class="p">(</span><span class="n">abs_pid</span> <span class="o">!=</span> <span class="mi">13</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">abs_pid</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">)</span>
+                <span class="p">),</span>  <span class="c1"># @TODO: `abs_pid in [12,14,16]`?</span>
+                <span class="s2">"v_e"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">12</span><span class="p">),</span>
+                <span class="s2">"v_u"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">14</span><span class="p">),</span>
+                <span class="s2">"v_t"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">16</span><span class="p">),</span>
+                <span class="s2">"track"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span>
+                    <span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">14</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">truth_dict</span><span class="p">[</span><span class="s2">"interaction_type"</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
+                <span class="p">),</span>
+                <span class="s2">"dbang"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_dbang_label</span><span class="p">(</span><span class="n">truth_dict</span><span class="p">),</span>
+                <span class="s2">"corsika"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">&gt;</span> <span class="mi">20</span><span class="p">),</span>
+            <span class="p">}</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">labels_dict</span> <span class="o">=</span> <span class="p">{</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">:</span> <span class="n">truth_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">],</span>
+                <span class="s2">"muon"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"muon_stopped"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"noise"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"neutrino"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"v_e"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"v_u"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"v_t"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"track"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"dbang"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+                <span class="s2">"corsika"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+            <span class="p">}</span>
+        <span class="k">return</span> <span class="n">labels_dict</span>
+
+    <span class="k">def</span> <span class="nf">_get_dbang_label</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">truth_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Get label for double-bang classification."""</span>
+        <span class="k">try</span><span class="p">:</span>
+            <span class="n">label</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">truth_dict</span><span class="p">[</span><span class="s2">"dbang_decay_length"</span><span class="p">]</span> <span class="o">&gt;</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
+            <span class="k">return</span> <span class="n">label</span>
+        <span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
+            <span class="k">return</span> <span class="o">-</span><span class="mi">1</span></div>
+
+
+
+<div class="viewcode-block" id="EnsembleDataset">
+<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.EnsembleDataset">[docs]</a>
+<span class="k">class</span> <span class="nc">EnsembleDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ConcatDataset</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Construct a single dataset from a collection of datasets."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Dataset</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct a single dataset from a collection of datasets.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            datasets: A collection of Datasets</span>
+<span class="sd">        """</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">datasets</span><span class="o">=</span><span class="n">datasets</span><span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/data/dataset/parquet/parquet_dataset.html b/_modules/graphnet/data/dataset/parquet/parquet_dataset.html
new file mode 100644
index 000000000..35d6b34d7
--- /dev/null
+++ b/_modules/graphnet/data/dataset/parquet/parquet_dataset.html
@@ -0,0 +1,500 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.data.dataset.parquet.parquet_dataset &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/data/dataset/parquet/parquet_dataset" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.data.dataset.parquet.parquet_dataset </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../../"versions.json"",
+        target_loc = "../../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-data-dataset-parquet-parquet-dataset--page-root">Source code for graphnet.data.dataset.parquet.parquet_dataset</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""`Dataset` class(es) for reading from Parquet files."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">cast</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">awkward</span> <span class="k">as</span> <span class="nn">ak</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.data.dataset.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">ColumnMissingException</span>
+
+
+<div class="viewcode-block" id="ParquetDataset">
+<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset">[docs]</a>
+<span class="k">class</span> <span class="nc">ParquetDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Pytorch dataset for reading from Parquet files."""</span>
+
+    <span class="c1"># Implementing abstract method(s)</span>
+    <span class="k">def</span> <span class="nf">_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
+
+            <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span>
+                <span class="s2">".parquet"</span>
+            <span class="p">),</span> <span class="sa">f</span><span class="s2">"Format of input file `</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="si">}</span><span class="s2">` is not supported"</span>
+
+        <span class="k">assert</span> <span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span> <span class="ow">is</span> <span class="kc">None</span>
+        <span class="p">),</span> <span class="s2">"Argument `node_truth` is currently not supported."</span>
+        <span class="k">assert</span> <span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_table</span> <span class="ow">is</span> <span class="kc">None</span>
+        <span class="p">),</span> <span class="s2">"Argument `node_truth_table` is currently not supported."</span>
+        <span class="k">assert</span> <span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection</span> <span class="ow">is</span> <span class="kc">None</span>
+        <span class="p">),</span> <span class="s2">"Argument `string_selection` is currently not supported"</span>
+
+        <span class="c1"># Set custom member variable(s)</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">from_parquet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="n">lazy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
+                <span class="n">ak</span><span class="o">.</span><span class="n">from_parquet</span><span class="p">(</span><span class="n">file</span><span class="p">)</span> <span class="k">for</span> <span class="n">file</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span>
+            <span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_get_all_indices</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
+        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span>
+            <span class="nb">len</span><span class="p">(</span>
+                <span class="n">ak</span><span class="o">.</span><span class="n">to_numpy</span><span class="p">(</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span><span class="p">][</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">]</span>
+                <span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+            <span class="p">)</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+
+    <span class="k">def</span> <span class="nf">_get_event_index</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
+        <span class="n">index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span>
+        <span class="k">if</span> <span class="n">sequential_index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">index</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">index</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">)[</span><span class="n">sequential_index</span><span class="p">]</span>
+
+        <span class="k">return</span> <span class="n">index</span>
+
+    <span class="k">def</span> <span class="nf">_format_dictionary_result</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">dictionary</span><span class="p">:</span> <span class="n">Dict</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Convert the output of `ak.to_list()` into a list of tuples."""</span>
+        <span class="c1"># All scalar values</span>
+        <span class="k">if</span> <span class="nb">all</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">,</span> <span class="n">dictionary</span><span class="o">.</span><span class="n">values</span><span class="p">())):</span>
+            <span class="k">return</span> <span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">dictionary</span><span class="o">.</span><span class="n">values</span><span class="p">())]</span>
+
+        <span class="c1"># All arrays should have same length</span>
+        <span class="n">array_lengths</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="nb">len</span><span class="p">(</span><span class="n">values</span><span class="p">)</span>
+            <span class="k">for</span> <span class="n">values</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="o">.</span><span class="n">values</span><span class="p">()</span>
+            <span class="k">if</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">values</span><span class="p">)</span>
+        <span class="p">]</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">array_lengths</span><span class="p">))</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span>
+            <span class="sa">f</span><span class="s2">"Arrays in </span><span class="si">{</span><span class="n">dictionary</span><span class="si">}</span><span class="s2"> have differing lengths "</span>
+            <span class="sa">f</span><span class="s2">"(</span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">array_lengths</span><span class="p">)</span><span class="si">}</span><span class="s2">)."</span>
+        <span class="p">)</span>
+        <span class="n">nb_elements</span> <span class="o">=</span> <span class="n">array_lengths</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+
+        <span class="c1"># Broadcast scalars</span>
+        <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span>
+            <span class="n">value</span> <span class="o">=</span> <span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
+            <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">value</span><span class="p">):</span>
+                <span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
+                    <span class="n">value</span><span class="p">,</span> <span class="n">repeats</span><span class="o">=</span><span class="n">nb_elements</span>
+                <span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+
+        <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">dictionary</span><span class="o">.</span><span class="n">values</span><span class="p">()))))</span>
+
+<div class="viewcode-block" id="ParquetDataset.query_table">
+<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table">[docs]</a>
+    <span class="k">def</span> <span class="nf">query_table</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">table</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">columns</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">],</span>
+        <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Query table at a specific index, optionally with some selection."""</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">assert</span> <span class="p">(</span>
+            <span class="n">selection</span> <span class="ow">is</span> <span class="kc">None</span>
+        <span class="p">),</span> <span class="s2">"Argument `selection` is currently not supported"</span>
+
+        <span class="n">index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_event_index</span><span class="p">(</span><span class="n">sequential_index</span><span class="p">)</span>
+
+        <span class="k">try</span><span class="p">:</span>
+            <span class="k">if</span> <span class="n">index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="n">ak_array</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span><span class="p">[</span><span class="n">table</span><span class="p">][</span><span class="n">columns</span><span class="p">][:]</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="n">ak_array</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span><span class="p">[</span><span class="n">table</span><span class="p">][</span><span class="n">columns</span><span class="p">][</span><span class="n">index</span><span class="p">]</span>
+        <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
+            <span class="k">if</span> <span class="s2">"does not exist (not in record)"</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">):</span>
+                <span class="k">raise</span> <span class="n">ColumnMissingException</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">))</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="k">raise</span> <span class="n">e</span>
+
+        <span class="n">output</span> <span class="o">=</span> <span class="n">ak_array</span><span class="o">.</span><span class="n">to_list</span><span class="p">()</span>
+
+        <span class="n">result</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]</span> <span class="o">=</span> <span class="p">[]</span>
+
+        <span class="c1"># Querying single index</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
+            <span class="k">assert</span> <span class="nb">list</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="o">==</span> <span class="n">columns</span>
+            <span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format_dictionary_result</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
+
+        <span class="c1"># Querying entire columm</span>
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">output</span><span class="p">:</span>
+                <span class="k">assert</span> <span class="nb">list</span><span class="p">(</span><span class="n">dictionary</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="o">==</span> <span class="n">columns</span>
+                <span class="n">result</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_format_dictionary_result</span><span class="p">(</span><span class="n">dictionary</span><span class="p">))</span>
+
+        <span class="k">return</span> <span class="n">result</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html
new file mode 100644
index 000000000..a6734e1c8
--- /dev/null
+++ b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html
@@ -0,0 +1,515 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.data.dataset.sqlite.sqlite_dataset &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/data/dataset/sqlite/sqlite_dataset" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.data.dataset.sqlite.sqlite_dataset </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../../"versions.json"",
+        target_loc = "../../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-data-dataset-sqlite-sqlite-dataset--page-root">Source code for graphnet.data.dataset.sqlite.sqlite_dataset</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""`Dataset` class(es) for reading data from SQLite databases."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
+<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
+<span class="kn">import</span> <span class="nn">sqlite3</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.data.dataset.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">ColumnMissingException</span>
+
+
+<div class="viewcode-block" id="SQLiteDataset">
+<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset">[docs]</a>
+<span class="k">class</span> <span class="nc">SQLiteDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Pytorch dataset for reading data from SQLite databases."""</span>
+
+    <span class="c1"># Implementing abstract method(s)</span>
+    <span class="k">def</span> <span class="nf">_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="c1"># Check(s)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span> <span class="o">=</span> <span class="kc">False</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">sqlite3</span><span class="o">.</span><span class="n">Connection</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="o">=</span> <span class="kc">None</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
+            <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span>
+                <span class="s2">".db"</span>
+            <span class="p">),</span> <span class="sa">f</span><span class="s2">"Format of input file `</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="si">}</span><span class="s2">` is not supported."</span>
+
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_current_database</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+
+        <span class="c1"># Set custom member variable(s)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_features_string</span> <span class="o">=</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_truth_string</span> <span class="o">=</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">)</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_string</span> <span class="o">=</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">sqlite3</span><span class="o">.</span><span class="n">Connection</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+
+    <span class="k">def</span> <span class="nf">_post_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_close_connection</span><span class="p">()</span>
+
+<div class="viewcode-block" id="SQLiteDataset.query_table">
+<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table">[docs]</a>
+    <span class="k">def</span> <span class="nf">query_table</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">table</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">columns</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">],</span>
+        <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Query table at a specific index, optionally with some selection."""</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">columns</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="n">columns</span> <span class="o">=</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">columns</span><span class="p">)</span>
+
+        <span class="k">if</span> <span class="ow">not</span> <span class="n">selection</span><span class="p">:</span>  <span class="c1"># I.e., `None` or `""`</span>
+            <span class="n">selection</span> <span class="o">=</span> <span class="s2">"1=1"</span>  <span class="c1"># Identically true, to select all</span>
+
+        <span class="n">index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_event_index</span><span class="p">(</span><span class="n">sequential_index</span><span class="p">)</span>
+
+        <span class="c1"># Query table</span>
+        <span class="k">assert</span> <span class="n">index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_establish_connection</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
+        <span class="k">try</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span>
+            <span class="k">if</span> <span class="n">sequential_index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="n">combined_selections</span> <span class="o">=</span> <span class="n">selection</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="n">combined_selections</span> <span class="o">=</span> <span class="p">(</span>
+                    <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2"> and </span><span class="si">{</span><span class="n">selection</span><span class="si">}</span><span class="s2">"</span>
+                <span class="p">)</span>
+
+            <span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span><span class="o">.</span><span class="n">execute</span><span class="p">(</span>
+                <span class="sa">f</span><span class="s2">"SELECT </span><span class="si">{</span><span class="n">columns</span><span class="si">}</span><span class="s2"> FROM </span><span class="si">{</span><span class="n">table</span><span class="si">}</span><span class="s2"> WHERE "</span>
+                <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">combined_selections</span><span class="si">}</span><span class="s2">"</span>
+            <span class="p">)</span><span class="o">.</span><span class="n">fetchall</span><span class="p">()</span>
+        <span class="k">except</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">OperationalError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
+            <span class="k">if</span> <span class="s2">"no such column"</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">):</span>
+                <span class="k">raise</span> <span class="n">ColumnMissingException</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">))</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="k">raise</span> <span class="n">e</span>
+        <span class="k">return</span> <span class="n">result</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_get_all_indices</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_establish_connection</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
+        <span class="n">indices</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_sql_query</span><span class="p">(</span>
+            <span class="sa">f</span><span class="s2">"SELECT </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="si">}</span><span class="s2"> FROM </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_close_connection</span><span class="p">()</span>
+        <span class="k">return</span> <span class="n">indices</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+
+    <span class="k">def</span> <span class="nf">_get_event_index</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
+        <span class="n">index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
+        <span class="k">if</span> <span class="n">sequential_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">index_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">[</span><span class="n">sequential_index</span><span class="p">]</span>
+            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index_</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
+                <span class="n">index</span> <span class="o">=</span> <span class="n">index_</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index_</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+                <span class="n">index</span> <span class="o">=</span> <span class="n">index_</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+        <span class="k">return</span> <span class="n">index</span>
+
+    <span class="c1"># Custom, internal method(s)</span>
+    <span class="c1"># @TODO: Is it necessary to return anything here?</span>
+    <span class="k">def</span> <span class="nf">_establish_connection</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">i</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"SQLiteDataset"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Make sure that a sqlite3 connection is open."""</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
+            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span> <span class="o">=</span> <span class="p">[]</span>
+                    <span class="k">for</span> <span class="n">database</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span><span class="p">:</span>
+                        <span class="n">con</span> <span class="o">=</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="n">database</span><span class="p">)</span>
+                        <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">con</span><span class="p">)</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span> <span class="o">=</span> <span class="kc">True</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="p">[</span><span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span>
+            <span class="k">if</span> <span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_current_database</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="p">[</span><span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_current_database</span> <span class="o">=</span> <span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
+        <span class="k">return</span> <span class="bp">self</span>
+
+    <span class="c1"># @TODO: Is it necessary to return anything here?</span>
+    <span class="k">def</span> <span class="nf">_close_connection</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"SQLiteDataset"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Make sure that no sqlite3 connection is open.</span>
+
+<span class="sd">        This is necessary to calls this before passing to</span>
+<span class="sd">        `torch.DataLoader` such that the dataset replica on each worker</span>
+<span class="sd">        is required to create its own connection (thereby avoiding</span>
+<span class="sd">        `sqlite3.DatabaseError: database disk image is malformed` errors</span>
+<span class="sd">        due to inability to use sqlite3 connection accross processes.</span>
+<span class="sd">        """</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
+            <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span><span class="p">:</span>
+                <span class="k">for</span> <span class="n">con</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="p">:</span>
+                    <span class="n">con</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
+                <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span> <span class="o">=</span> <span class="kc">False</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="k">return</span> <span class="bp">self</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html
new file mode 100644
index 000000000..44acbb9c5
--- /dev/null
+++ b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html
@@ -0,0 +1,515 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.data.dataset.sqlite.sqlite_dataset_perturbed &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.data.dataset.sqlite.sqlite_dataset_perturbed </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../../"versions.json"",
+        target_loc = "../../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root">Source code for graphnet.data.dataset.sqlite.sqlite_dataset_perturbed</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""`Dataset` class(es) for reading perturbed data from SQLite databases."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">from</span> <span class="nn">numpy.random</span> <span class="kn">import</span> <span class="n">default_rng</span><span class="p">,</span> <span class="n">Generator</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">.sqlite_dataset</span> <span class="kn">import</span> <span class="n">SQLiteDataset</span>
+
+
+<div class="viewcode-block" id="SQLiteDatasetPerturbed">
+<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed">[docs]</a>
+<span class="k">class</span> <span class="nc">SQLiteDatasetPerturbed</span><span class="p">(</span><span class="n">SQLiteDataset</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Pytorch dataset for reading perturbed data from SQLite databases.</span>
+
+<span class="sd">    This including a pre-processing step, where the input data is randomly</span>
+<span class="sd">    perturbed according to given per-feature "noise" levels. This is intended</span>
+<span class="sd">    to test the stability of a trained model under small changes to the input</span>
+<span class="sd">    parameters.</span>
+<span class="sd">    """</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
+        <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
+        <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">perturbation_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
+        <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span><span class="p">,</span>
+        <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span><span class="p">,</span>
+        <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">string_selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
+        <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Generator</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct SQLiteDatasetPerturbed.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            path: Path to the file(s) from which this `Dataset` should read.</span>
+<span class="sd">            pulsemaps: Name(s) of the pulse map series that should be used to</span>
+<span class="sd">                construct the nodes on the individual graph objects, and their</span>
+<span class="sd">                features. Multiple pulse series maps can be used, e.g., when</span>
+<span class="sd">                different DOM types are stored in different maps.</span>
+<span class="sd">            features: List of columns in the input files that should be used as</span>
+<span class="sd">                node features on the graph objects.</span>
+<span class="sd">            truth: List of event-level columns in the input files that should</span>
+<span class="sd">                be used added as attributes on the  graph objects.</span>
+<span class="sd">            perturbation_dict (Dict[str, float]): Dictionary mapping a feature</span>
+<span class="sd">                name to a standard deviation according to which the values for</span>
+<span class="sd">                this feature should be randomly perturbed.</span>
+<span class="sd">            node_truth: List of node-level columns in the input files that</span>
+<span class="sd">                should be used added as attributes on the graph objects.</span>
+<span class="sd">            index_column: Name of the column in the input files that contains</span>
+<span class="sd">                unique indicies to identify and map events across tables.</span>
+<span class="sd">            truth_table: Name of the table containing event-level truth</span>
+<span class="sd">                information.</span>
+<span class="sd">            node_truth_table: Name of the table containing node-level truth</span>
+<span class="sd">                information.</span>
+<span class="sd">            string_selection: Subset of strings for which data should be read</span>
+<span class="sd">                and used to construct graph objects. Defaults to None, meaning</span>
+<span class="sd">                all strings for which data exists are used.</span>
+<span class="sd">            selection: List of indicies (in `index_column`) of the events in</span>
+<span class="sd">                the input files that should be read. Defaults to None, meaning</span>
+<span class="sd">                that all events in the input files are read.</span>
+<span class="sd">            dtype: Type of the feature tensor on the graph objects returned.</span>
+<span class="sd">            loss_weight_table: Name of the table containing per-event loss</span>
+<span class="sd">                weights.</span>
+<span class="sd">            loss_weight_column: Name of the column in `loss_weight_table`</span>
+<span class="sd">                containing per-event loss weights. This is also the name of the</span>
+<span class="sd">                corresponding attribute assigned to the graph object.</span>
+<span class="sd">            loss_weight_default_value: Default per-event loss weight.</span>
+<span class="sd">                NOTE: This default value is only applied when</span>
+<span class="sd">                `loss_weight_table` and `loss_weight_column` are specified, and</span>
+<span class="sd">                in this case to events with no value in the corresponding</span>
+<span class="sd">                table/column. That is, if no per-event loss weight table/column</span>
+<span class="sd">                is provided, this value is ignored. Defaults to None.</span>
+<span class="sd">            seed: Optional seed for random number generation. Defaults to None.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">path</span><span class="o">=</span><span class="n">path</span><span class="p">,</span>
+            <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span>
+            <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span>
+            <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span>
+            <span class="n">node_truth</span><span class="o">=</span><span class="n">node_truth</span><span class="p">,</span>
+            <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span>
+            <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span>
+            <span class="n">node_truth_table</span><span class="o">=</span><span class="n">node_truth_table</span><span class="p">,</span>
+            <span class="n">string_selection</span><span class="o">=</span><span class="n">string_selection</span><span class="p">,</span>
+            <span class="n">selection</span><span class="o">=</span><span class="n">selection</span><span class="p">,</span>
+            <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
+            <span class="n">loss_weight_table</span><span class="o">=</span><span class="n">loss_weight_table</span><span class="p">,</span>
+            <span class="n">loss_weight_column</span><span class="o">=</span><span class="n">loss_weight_column</span><span class="p">,</span>
+            <span class="n">loss_weight_default_value</span><span class="o">=</span><span class="n">loss_weight_default_value</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Custom member variables</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">perturbation_dict</span><span class="p">,</span> <span class="nb">dict</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">perturbation_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
+            <span class="n">perturbation_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_dict</span> <span class="o">=</span> <span class="n">perturbation_dict</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_cols</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
+        <span class="p">]</span>
+
+        <span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">seed</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">rng</span> <span class="o">=</span> <span class="n">default_rng</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
+            <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">seed</span><span class="p">,</span> <span class="n">Generator</span><span class="p">):</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">rng</span> <span class="o">=</span> <span class="n">seed</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                    <span class="s2">"Invalid seed. Must be an int or a numpy Generator."</span>
+                <span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">rng</span> <span class="o">=</span> <span class="n">default_rng</span><span class="p">()</span>
+
+    <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return graph `Data` object at `index`."""</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">sequential_index</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)):</span>
+            <span class="k">raise</span> <span class="ne">IndexError</span><span class="p">(</span>
+                <span class="sa">f</span><span class="s2">"Index </span><span class="si">{</span><span class="n">sequential_index</span><span class="si">}</span><span class="s2"> not in range [0, </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">]"</span>
+            <span class="p">)</span>
+        <span class="n">features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_query</span><span class="p">(</span>
+            <span class="n">sequential_index</span>
+        <span class="p">)</span>
+        <span class="n">perturbed_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perturb_features</span><span class="p">(</span><span class="n">features</span><span class="p">)</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_graph</span><span class="p">(</span>
+            <span class="n">perturbed_features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="n">graph</span>
+
+    <span class="k">def</span> <span class="nf">_perturb_features</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="o">...</span><span class="p">]]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span>
+        <span class="n">features_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">features</span><span class="p">)</span>
+        <span class="n">perturbed_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rng</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span>
+            <span class="n">loc</span><span class="o">=</span><span class="n">features_array</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_cols</span><span class="p">],</span>
+            <span class="n">scale</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span>
+                <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_dict</span><span class="o">.</span><span class="n">values</span><span class="p">()),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float</span>
+            <span class="p">),</span>
+        <span class="p">)</span>
+        <span class="n">features_array</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_cols</span><span class="p">]</span> <span class="o">=</span> <span class="n">perturbed_features</span>
+        <span class="k">return</span> <span class="n">features_array</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/data/extractors/i3extractor.html b/_modules/graphnet/data/extractors/i3extractor.html
index cae615098..0680557f4 100644
--- a/_modules/graphnet/data/extractors/i3extractor.html
+++ b/_modules/graphnet/data/extractors/i3extractor.html
@@ -464,7 +464,7 @@ <h1 id="modules-graphnet-data-extractors-i3extractor--page-root">Source code for
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3featureextractor.html b/_modules/graphnet/data/extractors/i3featureextractor.html
index 06d9d3e68..de384a97e 100644
--- a/_modules/graphnet/data/extractors/i3featureextractor.html
+++ b/_modules/graphnet/data/extractors/i3featureextractor.html
@@ -647,7 +647,7 @@ <h1 id="modules-graphnet-data-extractors-i3featureextractor--page-root">Source c
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3genericextractor.html b/_modules/graphnet/data/extractors/i3genericextractor.html
index 0324f9087..9c7d786f3 100644
--- a/_modules/graphnet/data/extractors/i3genericextractor.html
+++ b/_modules/graphnet/data/extractors/i3genericextractor.html
@@ -636,7 +636,7 @@ <h1 id="modules-graphnet-data-extractors-i3genericextractor--page-root">Source c
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3hybridrecoextractor.html b/_modules/graphnet/data/extractors/i3hybridrecoextractor.html
index 21824f54b..ee80f730f 100644
--- a/_modules/graphnet/data/extractors/i3hybridrecoextractor.html
+++ b/_modules/graphnet/data/extractors/i3hybridrecoextractor.html
@@ -400,7 +400,7 @@ <h1 id="modules-graphnet-data-extractors-i3hybridrecoextractor--page-root">Sourc
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html b/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html
index 2a3a45023..d295f81a2 100644
--- a/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html
+++ b/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html
@@ -407,7 +407,7 @@ <h1 id="modules-graphnet-data-extractors-i3ntmuonlabelsextractor--page-root">Sou
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3particleextractor.html b/_modules/graphnet/data/extractors/i3particleextractor.html
index 8d17ef0ee..60463b59a 100644
--- a/_modules/graphnet/data/extractors/i3particleextractor.html
+++ b/_modules/graphnet/data/extractors/i3particleextractor.html
@@ -392,7 +392,7 @@ <h1 id="modules-graphnet-data-extractors-i3particleextractor--page-root">Source
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3pisaextractor.html b/_modules/graphnet/data/extractors/i3pisaextractor.html
index 635ba4537..0752dc39a 100644
--- a/_modules/graphnet/data/extractors/i3pisaextractor.html
+++ b/_modules/graphnet/data/extractors/i3pisaextractor.html
@@ -385,7 +385,7 @@ <h1 id="modules-graphnet-data-extractors-i3pisaextractor--page-root">Source code
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3quesoextractor.html b/_modules/graphnet/data/extractors/i3quesoextractor.html
index bcac49e76..6ad9572a6 100644
--- a/_modules/graphnet/data/extractors/i3quesoextractor.html
+++ b/_modules/graphnet/data/extractors/i3quesoextractor.html
@@ -395,7 +395,7 @@ <h1 id="modules-graphnet-data-extractors-i3quesoextractor--page-root">Source cod
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3retroextractor.html b/_modules/graphnet/data/extractors/i3retroextractor.html
index dea96d47a..b76bed436 100644
--- a/_modules/graphnet/data/extractors/i3retroextractor.html
+++ b/_modules/graphnet/data/extractors/i3retroextractor.html
@@ -467,7 +467,7 @@ <h1 id="modules-graphnet-data-extractors-i3retroextractor--page-root">Source cod
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3splinempeextractor.html b/_modules/graphnet/data/extractors/i3splinempeextractor.html
index 79f5f9457..dd8d383bd 100644
--- a/_modules/graphnet/data/extractors/i3splinempeextractor.html
+++ b/_modules/graphnet/data/extractors/i3splinempeextractor.html
@@ -379,7 +379,7 @@ <h1 id="modules-graphnet-data-extractors-i3splinempeextractor--page-root">Source
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3truthextractor.html b/_modules/graphnet/data/extractors/i3truthextractor.html
index 330ad508b..19d005d0d 100644
--- a/_modules/graphnet/data/extractors/i3truthextractor.html
+++ b/_modules/graphnet/data/extractors/i3truthextractor.html
@@ -781,7 +781,7 @@ <h1 id="modules-graphnet-data-extractors-i3truthextractor--page-root">Source cod
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/i3tumextractor.html b/_modules/graphnet/data/extractors/i3tumextractor.html
index d22868f07..89adb77c2 100644
--- a/_modules/graphnet/data/extractors/i3tumextractor.html
+++ b/_modules/graphnet/data/extractors/i3tumextractor.html
@@ -382,7 +382,7 @@ <h1 id="modules-graphnet-data-extractors-i3tumextractor--page-root">Source code
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/utilities/collections.html b/_modules/graphnet/data/extractors/utilities/collections.html
index 11ccbd8f5..14835acc6 100644
--- a/_modules/graphnet/data/extractors/utilities/collections.html
+++ b/_modules/graphnet/data/extractors/utilities/collections.html
@@ -436,7 +436,7 @@ <h1 id="modules-graphnet-data-extractors-utilities-collections--page-root">Sourc
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/utilities/frames.html b/_modules/graphnet/data/extractors/utilities/frames.html
index 4dfe93249..656cb6033 100644
--- a/_modules/graphnet/data/extractors/utilities/frames.html
+++ b/_modules/graphnet/data/extractors/utilities/frames.html
@@ -439,7 +439,7 @@ <h1 id="modules-graphnet-data-extractors-utilities-frames--page-root">Source cod
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/extractors/utilities/types.html b/_modules/graphnet/data/extractors/utilities/types.html
index f2d4ecd65..60de6d341 100644
--- a/_modules/graphnet/data/extractors/utilities/types.html
+++ b/_modules/graphnet/data/extractors/utilities/types.html
@@ -660,7 +660,7 @@ <h1 id="modules-graphnet-data-extractors-utilities-types--page-root">Source code
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/parquet/parquet_dataconverter.html b/_modules/graphnet/data/parquet/parquet_dataconverter.html
index 7cf5c3507..455153448 100644
--- a/_modules/graphnet/data/parquet/parquet_dataconverter.html
+++ b/_modules/graphnet/data/parquet/parquet_dataconverter.html
@@ -407,7 +407,7 @@ <h1 id="modules-graphnet-data-parquet-parquet-dataconverter--page-root">Source c
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/pipeline.html b/_modules/graphnet/data/pipeline.html
new file mode 100644
index 000000000..90602e63c
--- /dev/null
+++ b/_modules/graphnet/data/pipeline.html
@@ -0,0 +1,593 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.data.pipeline &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/data/pipeline" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.data.pipeline </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-data-pipeline--page-root">Source code for graphnet.data.pipeline</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Class(es) used for analysis in PISA."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span>
+<span class="kn">import</span> <span class="nn">dill</span>
+<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">reduce</span>
+<span class="kn">import</span> <span class="nn">os</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">Trainer</span>
+<span class="kn">import</span> <span class="nn">sqlite3</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.data.sqlite.sqlite_utilities</span> <span class="kn">import</span> <span class="n">create_table_and_save_to_sql</span>
+<span class="kn">from</span> <span class="nn">graphnet.training.utils</span> <span class="kn">import</span> <span class="n">get_predictions</span><span class="p">,</span> <span class="n">make_dataloader</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span>
+
+
+<div class="viewcode-block" id="InSQLitePipeline">
+<a class="viewcode-back" href="../../../api/graphnet.data.pipeline.html#graphnet.data.pipeline.InSQLitePipeline">[docs]</a>
+<span class="k">class</span> <span class="nc">InSQLitePipeline</span><span class="p">(</span><span class="n">ABC</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Create a SQLite database for PISA analysis.</span>
+
+<span class="sd">    The database will contain truth and GNN predictions and, if available,</span>
+<span class="sd">    RETRO reconstructions.</span>
+<span class="sd">    """</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">module_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span>
+        <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
+        <span class="n">retro_table_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"retro"</span><span class="p">,</span>
+        <span class="n">outdir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
+        <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
+        <span class="n">pipeline_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"pipeline"</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Initialise the pipeline.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            module_dict: A dictionary with GNN modules from GraphNet. E.g.</span>
+<span class="sd">                {'energy': gnn_module_for_energy_regression}</span>
+<span class="sd">            features: List of input features for the GNN modules.</span>
+<span class="sd">            truth: List of truth for the GNN ModuleList.</span>
+<span class="sd">            device: The device used for computation.</span>
+<span class="sd">            retro_table_name: Name of the retro table for.</span>
+<span class="sd">            outdir: the directory in which the pipeline database will be</span>
+<span class="sd">                stored.</span>
+<span class="sd">            batch_size: Batch size for inference.</span>
+<span class="sd">            n_workers: Number of workers used in dataloading.</span>
+<span class="sd">            pipeline_name: Name of the pipeline. If such a pipeline already</span>
+<span class="sd">                exists, an error will be prompted to avoid overwriting.</span>
+<span class="sd">        """</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_pipeline_name</span> <span class="o">=</span> <span class="n">pipeline_name</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_device</span> <span class="o">=</span> <span class="n">device</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_features</span> <span class="o">=</span> <span class="n">features</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span> <span class="o">=</span> <span class="n">truth</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_outdir</span> <span class="o">=</span> <span class="n">outdir</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_module_dict</span> <span class="o">=</span> <span class="n">module_dict</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_retro_table_name</span> <span class="o">=</span> <span class="n">retro_table_name</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">database</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span>
+        <span class="n">chunk_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1000000</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Run inference of each field in self._module_dict[target][''].</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            database: Path to database with pulsemap and truth.</span>
+<span class="sd">            pulsemap: Name of pulsemaps.</span>
+<span class="sd">            graph_definition: GraphDefinition for Dataset</span>
+<span class="sd">            chunk_size: database will be sliced in chunks of size `chunk_size`.</span>
+<span class="sd">                Use this parameter to control memory usage.</span>
+<span class="sd">        """</span>
+        <span class="n">outdir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_outdir</span><span class="p">(</span><span class="n">database</span><span class="p">)</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_device</span><span class="p">,</span> <span class="nb">str</span>
+        <span class="p">):</span>  <span class="c1"># Because pytorch lightning insists on breaking pytorch cuda device naming scheme</span>
+            <span class="n">device</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_device</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">outdir</span><span class="p">):</span>
+            <span class="n">dataloaders</span><span class="p">,</span> <span class="n">event_batches</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_setup_dataloaders</span><span class="p">(</span>
+                <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span>
+                <span class="n">chunk_size</span><span class="o">=</span><span class="n">chunk_size</span><span class="p">,</span>
+                <span class="n">db</span><span class="o">=</span><span class="n">database</span><span class="p">,</span>
+                <span class="n">pulsemap</span><span class="o">=</span><span class="n">pulsemap</span><span class="p">,</span>
+                <span class="n">selection</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
+                <span class="n">persistent_workers</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
+            <span class="p">)</span>
+            <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span>
+            <span class="k">for</span> <span class="n">dataloader</span> <span class="ow">in</span> <span class="n">dataloaders</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">"CHUNK </span><span class="si">%s</span><span class="s2"> / </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataloaders</span><span class="p">)))</span>
+                <span class="n">df</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">)</span>
+                <span class="n">truth</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_truth</span><span class="p">(</span><span class="n">database</span><span class="p">,</span> <span class="n">event_batches</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
+                <span class="n">retro</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_retro</span><span class="p">(</span><span class="n">database</span><span class="p">,</span> <span class="n">event_batches</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_append_to_pipeline</span><span class="p">(</span><span class="n">outdir</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">retro</span><span class="p">,</span> <span class="n">df</span><span class="p">)</span>
+                <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">outdir</span><span class="p">)</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
+                <span class="s2">"WARNING - Pipeline named </span><span class="si">%s</span><span class="s2"> already exists! </span><span class="se">\n</span><span class="s2"> Please rename pipeline!"</span>
+                <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pipeline_name</span>
+            <span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_setup_dataloaders</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">chunk_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span>
+        <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">persistent_workers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">DataLoader</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]]:</span>
+        <span class="k">if</span> <span class="n">selection</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">selection</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_event_nos</span><span class="p">(</span><span class="n">db</span><span class="p">)</span>
+        <span class="n">n_chunks</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">selection</span><span class="p">)</span> <span class="o">/</span> <span class="n">chunk_size</span><span class="p">)</span>
+        <span class="n">event_batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array_split</span><span class="p">(</span><span class="n">selection</span><span class="p">,</span> <span class="n">n_chunks</span><span class="p">)</span>
+        <span class="n">dataloaders</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">event_batches</span><span class="p">:</span>
+            <span class="n">dataloaders</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
+                <span class="n">make_dataloader</span><span class="p">(</span>
+                    <span class="n">db</span><span class="o">=</span><span class="n">db</span><span class="p">,</span>
+                    <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span>
+                    <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemap</span><span class="p">,</span>
+                    <span class="n">features</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">,</span>
+                    <span class="n">truth</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">,</span>
+                    <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_batch_size</span><span class="p">,</span>
+                    <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
+                    <span class="n">selection</span><span class="o">=</span><span class="n">batch</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
+                    <span class="n">num_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span>
+                    <span class="n">persistent_workers</span><span class="o">=</span><span class="n">persistent_workers</span><span class="p">,</span>
+                <span class="p">)</span>
+            <span class="p">)</span>
+        <span class="k">return</span> <span class="n">dataloaders</span><span class="p">,</span> <span class="n">event_batches</span>
+
+    <span class="k">def</span> <span class="nf">_get_all_event_nos</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
+        <span class="k">with</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="n">db</span><span class="p">)</span> <span class="k">as</span> <span class="n">con</span><span class="p">:</span>
+            <span class="n">query</span> <span class="o">=</span> <span class="s2">"SELECT event_no FROM truth"</span>
+            <span class="n">selection</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_sql</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">con</span><span class="p">)</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+        <span class="k">return</span> <span class="n">selection</span>
+
+    <span class="k">def</span> <span class="nf">_combine_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataframes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">merge</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">on</span><span class="o">=</span><span class="s2">"event_no"</span><span class="p">),</span> <span class="n">dataframes</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_inference</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span>
+        <span class="n">dataframes</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="k">for</span> <span class="n">target</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_module_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+            <span class="c1"># dataloader = iter(dataloader)</span>
+            <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">devices</span><span class="o">=</span><span class="p">[</span><span class="n">device</span><span class="p">],</span> <span class="n">accelerator</span><span class="o">=</span><span class="s2">"gpu"</span><span class="p">)</span>
+            <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_module_dict</span><span class="p">[</span><span class="n">target</span><span class="p">][</span><span class="s2">"path"</span><span class="p">],</span>
+                <span class="n">map_location</span><span class="o">=</span><span class="s2">"cpu"</span><span class="p">,</span>
+                <span class="n">pickle_module</span><span class="o">=</span><span class="n">dill</span><span class="p">,</span>
+            <span class="p">)</span>
+            <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
+            <span class="n">model</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span>
+            <span class="n">results</span> <span class="o">=</span> <span class="n">get_predictions</span><span class="p">(</span>
+                <span class="n">trainer</span><span class="p">,</span>
+                <span class="n">model</span><span class="p">,</span>
+                <span class="n">dataloader</span><span class="p">,</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_module_dict</span><span class="p">[</span><span class="n">target</span><span class="p">][</span><span class="s2">"output_column_names"</span><span class="p">],</span>
+                <span class="n">additional_attributes</span><span class="o">=</span><span class="p">[</span><span class="s2">"event_no"</span><span class="p">],</span>
+            <span class="p">)</span>
+            <span class="n">dataframes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
+                <span class="n">results</span><span class="o">.</span><span class="n">sort_values</span><span class="p">(</span><span class="s2">"event_no"</span><span class="p">)</span><span class="o">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+            <span class="p">)</span>
+            <span class="n">df</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_combine_outputs</span><span class="p">(</span><span class="n">dataframes</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">df</span>
+
+    <span class="k">def</span> <span class="nf">_get_outdir</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">database</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_outdir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">database_name</span> <span class="o">=</span> <span class="n">database</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">3</span><span class="p">]</span>
+            <span class="n">outdir</span> <span class="o">=</span> <span class="p">(</span>
+                <span class="n">database</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">database_name</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
+                <span class="o">+</span> <span class="n">database_name</span>
+                <span class="o">+</span> <span class="s2">"/pipelines/"</span>
+                <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pipeline_name</span>
+            <span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">outdir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_outdir</span>
+        <span class="k">return</span> <span class="n">outdir</span>
+
+    <span class="k">def</span> <span class="nf">_get_truth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">database</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">selection</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span>
+        <span class="k">with</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="n">database</span><span class="p">)</span> <span class="k">as</span> <span class="n">con</span><span class="p">:</span>
+            <span class="n">query</span> <span class="o">=</span> <span class="s2">"SELECT * FROM truth WHERE event_no in </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span>
+                <span class="nb">tuple</span><span class="p">(</span><span class="n">selection</span><span class="p">)</span>
+            <span class="p">)</span>
+            <span class="n">truth</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_sql</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">con</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">truth</span>
+
+    <span class="k">def</span> <span class="nf">_get_retro</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">database</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">selection</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span>
+        <span class="k">try</span><span class="p">:</span>
+            <span class="k">with</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="n">database</span><span class="p">)</span> <span class="k">as</span> <span class="n">con</span><span class="p">:</span>
+                <span class="n">query</span> <span class="o">=</span> <span class="s2">"SELECT * FROM </span><span class="si">%s</span><span class="s2"> WHERE event_no in </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">_retro_table_name</span><span class="p">,</span>
+                    <span class="nb">str</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">selection</span><span class="p">)),</span>
+                <span class="p">)</span>
+                <span class="n">retro</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_sql</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">con</span><span class="p">)</span>
+            <span class="k">return</span> <span class="n">retro</span>
+        <span class="k">except</span><span class="p">:</span>  <span class="c1"># noqa: E722</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> table does not exist"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">_retro_table_name</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_append_to_pipeline</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">outdir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">truth</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span>
+        <span class="n">retro</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span>
+        <span class="n">df</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">outdir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="n">pipeline_database</span> <span class="o">=</span> <span class="n">outdir</span> <span class="o">+</span> <span class="s2">"/</span><span class="si">%s</span><span class="s2">.db"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pipeline_name</span>
+        <span class="n">create_table_and_save_to_sql</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="s2">"reconstruction"</span><span class="p">,</span> <span class="n">pipeline_database</span><span class="p">)</span>
+        <span class="n">create_table_and_save_to_sql</span><span class="p">(</span><span class="n">truth</span><span class="p">,</span> <span class="s2">"truth"</span><span class="p">,</span> <span class="n">pipeline_database</span><span class="p">)</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">retro</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">):</span>
+            <span class="n">create_table_and_save_to_sql</span><span class="p">(</span>
+                <span class="n">retro</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_retro_table_name</span><span class="p">,</span> <span class="n">pipeline_database</span>
+            <span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/data/sqlite/sqlite_dataconverter.html b/_modules/graphnet/data/sqlite/sqlite_dataconverter.html
index a067a451a..69fa21bbe 100644
--- a/_modules/graphnet/data/sqlite/sqlite_dataconverter.html
+++ b/_modules/graphnet/data/sqlite/sqlite_dataconverter.html
@@ -716,7 +716,7 @@ <h1 id="modules-graphnet-data-sqlite-sqlite-dataconverter--page-root">Source cod
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/sqlite/sqlite_utilities.html b/_modules/graphnet/data/sqlite/sqlite_utilities.html
index 29d4e537d..a910e007d 100644
--- a/_modules/graphnet/data/sqlite/sqlite_utilities.html
+++ b/_modules/graphnet/data/sqlite/sqlite_utilities.html
@@ -515,7 +515,7 @@ <h1 id="modules-graphnet-data-sqlite-sqlite-utilities--page-root">Source code fo
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/utilities/parquet_to_sqlite.html b/_modules/graphnet/data/utilities/parquet_to_sqlite.html
index e645a35f4..f3944caa0 100644
--- a/_modules/graphnet/data/utilities/parquet_to_sqlite.html
+++ b/_modules/graphnet/data/utilities/parquet_to_sqlite.html
@@ -533,7 +533,7 @@ <h1 id="modules-graphnet-data-utilities-parquet-to-sqlite--page-root">Source cod
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/utilities/random.html b/_modules/graphnet/data/utilities/random.html
index f297fc67c..eb3a7f85c 100644
--- a/_modules/graphnet/data/utilities/random.html
+++ b/_modules/graphnet/data/utilities/random.html
@@ -375,7 +375,7 @@ <h1 id="modules-graphnet-data-utilities-random--page-root">Source code for graph
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/data/utilities/string_selection_resolver.html b/_modules/graphnet/data/utilities/string_selection_resolver.html
index bd7909f33..361902f54 100644
--- a/_modules/graphnet/data/utilities/string_selection_resolver.html
+++ b/_modules/graphnet/data/utilities/string_selection_resolver.html
@@ -676,7 +676,7 @@ <h1 id="modules-graphnet-data-utilities-string-selection-resolver--page-root">So
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/deployment/i3modules/graphnet_module.html b/_modules/graphnet/deployment/i3modules/graphnet_module.html
new file mode 100644
index 000000000..994461c04
--- /dev/null
+++ b/_modules/graphnet/deployment/i3modules/graphnet_module.html
@@ -0,0 +1,817 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.deployment.i3modules.graphnet_module &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/deployment/i3modules/graphnet_module" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.deployment.i3modules.graphnet_module </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-deployment-i3modules-graphnet-module--page-root">Source code for graphnet.deployment.i3modules.graphnet_module</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Class(es) for deploying GraphNeT models in icetray as I3Modules."""</span>
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">TYPE_CHECKING</span><span class="p">,</span> <span class="n">Any</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span>
+
+<span class="kn">import</span> <span class="nn">dill</span>
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.data.extractors</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">I3FeatureExtractor</span><span class="p">,</span>
+    <span class="n">I3FeatureExtractorIceCubeUpgrade</span><span class="p">,</span>
+<span class="p">)</span>
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span><span class="p">,</span> <span class="n">StandardModel</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.imports</span> <span class="kn">import</span> <span class="n">has_icecube_package</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">ModelConfig</span>
+
+<span class="k">if</span> <span class="n">has_icecube_package</span><span class="p">()</span> <span class="ow">or</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
+    <span class="kn">from</span> <span class="nn">icecube.icetray</span> <span class="kn">import</span> <span class="p">(</span>
+        <span class="n">I3Module</span><span class="p">,</span>
+        <span class="n">I3Frame</span><span class="p">,</span>
+    <span class="p">)</span>  <span class="c1"># pyright: reportMissingImports=false</span>
+    <span class="kn">from</span> <span class="nn">icecube.dataclasses</span> <span class="kn">import</span> <span class="p">(</span>
+        <span class="n">I3Double</span><span class="p">,</span>
+        <span class="n">I3MapKeyVectorDouble</span><span class="p">,</span>
+    <span class="p">)</span>  <span class="c1"># pyright: reportMissingImports=false</span>
+    <span class="kn">from</span> <span class="nn">icecube</span> <span class="kn">import</span> <span class="n">dataclasses</span><span class="p">,</span> <span class="n">dataio</span><span class="p">,</span> <span class="n">icetray</span>
+
+
+<div class="viewcode-block" id="GraphNeTI3Module">
+<a class="viewcode-back" href="../../../../api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module">[docs]</a>
+<span class="k">class</span> <span class="nc">GraphNeTI3Module</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Base I3 Module for GraphNeT.</span>
+
+<span class="sd">    Contains methods for extracting pulsemaps, producing graphs and writing to</span>
+<span class="sd">    frames.</span>
+<span class="sd">    """</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span>
+        <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">pulsemap_extractor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span>
+            <span class="n">List</span><span class="p">[</span><span class="n">I3FeatureExtractor</span><span class="p">],</span> <span class="n">I3FeatureExtractor</span>
+        <span class="p">],</span>
+        <span class="n">gcd_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""I3Module Constructor.</span>
+
+<span class="sd">        Arguments:</span>
+<span class="sd">            graph_definition: An instance of GraphDefinition.  E.g. KNNGraph.</span>
+<span class="sd">            pulsemap: the pulse map on which the module functions</span>
+<span class="sd">            features: the features that is used from the pulse map.</span>
+<span class="sd">                      E.g. [dom_x, dom_y, dom_z, charge]</span>
+<span class="sd">            pulsemap_extractor: The I3FeatureExtractor used to extract the</span>
+<span class="sd">                                pulsemap from the I3Frames</span>
+<span class="sd">            gcd_file: Path to the associated gcd-file.</span>
+<span class="sd">        """</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">graph_definition</span><span class="p">,</span> <span class="n">GraphDefinition</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span> <span class="o">=</span> <span class="n">graph_definition</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemap</span> <span class="o">=</span> <span class="n">pulsemap</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_features</span> <span class="o">=</span> <span class="n">features</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">gcd_file</span><span class="p">,</span> <span class="nb">str</span><span class="p">),</span> <span class="s2">"gcd_file must be string"</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_gcd_file</span> <span class="o">=</span> <span class="n">gcd_file</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pulsemap_extractor</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span> <span class="o">=</span> <span class="n">pulsemap_extractor</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span> <span class="o">=</span> <span class="p">[</span><span class="n">pulsemap_extractor</span><span class="p">]</span>
+
+        <span class="k">for</span> <span class="n">i3_extractor</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span><span class="p">:</span>
+            <span class="n">i3_extractor</span><span class="o">.</span><span class="n">set_files</span><span class="p">(</span><span class="n">i3_file</span><span class="o">=</span><span class="s2">""</span><span class="p">,</span> <span class="n">gcd_file</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_gcd_file</span><span class="p">)</span>
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Define here how the module acts on the frame.</span>
+
+<span class="sd">        Must return True if successful.</span>
+
+<span class="sd">        Return True # SUPER IMPORTANT</span>
+<span class="sd">        """</span>
+
+    <span class="k">def</span> <span class="nf">_make_graph</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>  <span class="c1"># py-l-i-n-t-:- -d-i-s-able=invalid-name</span>
+<span class="w">        </span><span class="sd">"""Process Physics I3Frame into graph."""</span>
+        <span class="c1"># Extract features</span>
+        <span class="n">node_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_extract_feature_array_from_frame</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span>
+        <span class="c1"># Prepare graph data</span>
+        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">node_features</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
+            <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span><span class="p">(</span>
+                <span class="n">node_features</span><span class="o">=</span><span class="n">node_features</span><span class="p">,</span>
+                <span class="n">node_feature_names</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">,</span>
+            <span class="p">)</span>
+            <span class="k">return</span> <span class="n">Batch</span><span class="o">.</span><span class="n">from_data_list</span><span class="p">([</span><span class="n">data</span><span class="p">])</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">return</span> <span class="kc">None</span>
+
+    <span class="k">def</span> <span class="nf">_extract_feature_array_from_frame</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Apply the I3FeatureExtractors to the I3Frame.</span>
+
+<span class="sd">        Arguments:</span>
+<span class="sd">            frame: Physics I3Frame (PFrame)</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            array with pulsemap</span>
+<span class="sd">        """</span>
+        <span class="n">features</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="k">for</span> <span class="n">i3extractor</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span><span class="p">:</span>
+            <span class="n">feature_dict</span> <span class="o">=</span> <span class="n">i3extractor</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span>
+            <span class="n">features_pulsemap</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span>
+                <span class="p">[</span><span class="n">feature_dict</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">]</span>
+            <span class="p">)</span><span class="o">.</span><span class="n">T</span>
+            <span class="k">if</span> <span class="n">features</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="n">features</span> <span class="o">=</span> <span class="n">features_pulsemap</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="n">features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
+                    <span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">features_pulsemap</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span>
+                <span class="p">)</span>
+        <span class="k">return</span> <span class="n">features</span>
+
+    <span class="k">def</span> <span class="nf">_add_to_frame</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">I3Frame</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Add every field in data to I3Frame.</span>
+
+<span class="sd">        Arguments:</span>
+<span class="sd">            frame: I3Frame (physics)</span>
+<span class="sd">            data: Dictionary containing content that will be written to frame.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            frame: Same I3Frame as input, but with the new entries</span>
+<span class="sd">        """</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span>
+            <span class="n">data</span><span class="p">,</span> <span class="nb">dict</span>
+        <span class="p">),</span> <span class="sa">f</span><span class="s2">"data must be of type dict. Got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">data</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
+        <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">data</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+            <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="p">:</span>
+                <span class="n">frame</span><span class="o">.</span><span class="n">Put</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="n">key</span><span class="p">])</span>
+        <span class="k">return</span> <span class="n">frame</span></div>
+
+
+
+<div class="viewcode-block" id="I3InferenceModule">
+<a class="viewcode-back" href="../../../../api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule">[docs]</a>
+<span class="k">class</span> <span class="nc">I3InferenceModule</span><span class="p">(</span><span class="n">GraphNeTI3Module</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""General class for inference on i3 frames."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">pulsemap_extractor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span>
+            <span class="n">List</span><span class="p">[</span><span class="n">I3FeatureExtractor</span><span class="p">],</span> <span class="n">I3FeatureExtractor</span>
+        <span class="p">],</span>
+        <span class="n">model_config</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span>
+        <span class="n">state_dict</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">gcd_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""General class for inference on I3Frames (physics).</span>
+
+<span class="sd">        Arguments:</span>
+<span class="sd">            pulsemap: the pulsmap that the model is expecting as input.</span>
+<span class="sd">            features: the features of the pulsemap that the model is expecting.</span>
+<span class="sd">            pulsemap_extractor: The extractor used to extract the pulsemap.</span>
+<span class="sd">            model_config: The ModelConfig (or path to it) that summarizes the</span>
+<span class="sd">                            model used for inference.</span>
+<span class="sd">            state_dict: Path to state_dict containing the learned weights.</span>
+<span class="sd">            model_name: The name used for the model. Will help define the</span>
+<span class="sd">                        named entry in the I3Frame. E.g. "dynedge".</span>
+<span class="sd">            gcd_file: path to associated gcd file.</span>
+<span class="sd">            prediction_columns: column names for the predictions of the model.</span>
+<span class="sd">                               Will help define the named entry in the I3Frame.</span>
+<span class="sd">                                E.g. ['energy_reco']. Optional.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Construct model &amp; load weights</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">trust</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span>
+
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">pulsemap</span><span class="o">=</span><span class="n">pulsemap</span><span class="p">,</span>
+            <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span>
+            <span class="n">pulsemap_extractor</span><span class="o">=</span><span class="n">pulsemap_extractor</span><span class="p">,</span>
+            <span class="n">gcd_file</span><span class="o">=</span><span class="n">gcd_file</span><span class="p">,</span>
+            <span class="n">graph_definition</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">_graph_definition</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">"cpu"</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">prediction_columns</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span> <span class="o">=</span> <span class="p">[</span><span class="n">prediction_columns</span><span class="p">]</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span> <span class="o">=</span> <span class="n">prediction_columns</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">prediction_labels</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">model_name</span> <span class="o">=</span> <span class="n">model_name</span>
+
+    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Write predictions from model to frame."""</span>
+        <span class="c1"># inference</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_graph</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">graph</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">predictions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
+                <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">],</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="p">)</span>
+            <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="p">))</span>
+        <span class="c1"># Check dimensions of predictions and prediction columns</span>
+        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
+            <span class="n">dim</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">dim</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="n">dim</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span>
+        <span class="p">),</span> <span class="sa">f</span><span class="s2">"""predictions have shape </span><span class="si">{</span><span class="n">dim</span><span class="si">}</span><span class="s2"> but </span><span class="se">\n</span>
+<span class="s2">            prediction columns have [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="si">}</span><span class="s2">]"""</span>
+
+        <span class="c1"># Build Dictionary of predictions</span>
+        <span class="n">data</span> <span class="o">=</span> <span class="p">{}</span>
+        <span class="k">assert</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span>
+        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">dim</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)):</span>
+            <span class="k">try</span><span class="p">:</span>
+                <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">[:,</span> <span class="n">i</span><span class="p">])</span> <span class="o">==</span> <span class="mi">1</span>
+                <span class="n">data</span><span class="p">[</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">model_name</span> <span class="o">+</span> <span class="s2">"_"</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
+                <span class="p">]</span> <span class="o">=</span> <span class="n">I3Double</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">predictions</span><span class="p">[:,</span> <span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>
+            <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span>
+                <span class="n">data</span><span class="p">[</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">model_name</span> <span class="o">+</span> <span class="s2">"_"</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
+                <span class="p">]</span> <span class="o">=</span> <span class="n">I3Double</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
+
+        <span class="c1"># Submission methods</span>
+        <span class="n">frame</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_to_frame</span><span class="p">(</span><span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">)</span>
+        <span class="k">return</span> <span class="kc">True</span>
+
+    <span class="k">def</span> <span class="nf">_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
+        <span class="c1"># Perform inference</span>
+        <span class="n">task_predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="p">(</span>
+            <span class="nb">len</span><span class="p">(</span><span class="n">task_predictions</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
+        <span class="p">),</span> <span class="sa">f</span><span class="s2">"""This method assumes a single task. </span><span class="se">\n</span>
+<span class="s2">               Got </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">task_predictions</span><span class="p">)</span><span class="si">}</span><span class="s2"> tasks."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></div>
+
+
+
+<div class="viewcode-block" id="I3PulseCleanerModule">
+<a class="viewcode-back" href="../../../../api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule">[docs]</a>
+<span class="k">class</span> <span class="nc">I3PulseCleanerModule</span><span class="p">(</span><span class="n">I3InferenceModule</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""A specialized module for pulse cleaning.</span>
+
+<span class="sd">    It is assumed that the model provided has been trained for this.</span>
+<span class="sd">    """</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">pulsemap_extractor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span>
+            <span class="n">List</span><span class="p">[</span><span class="n">I3FeatureExtractor</span><span class="p">],</span> <span class="n">I3FeatureExtractor</span>
+        <span class="p">],</span>
+        <span class="n">model_config</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">state_dict</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">gcd_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span><span class="p">,</span>
+        <span class="n">discard_empty_events</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""General class for inference on I3Frames (physics).</span>
+
+<span class="sd">        Arguments:</span>
+<span class="sd">            pulsemap: the pulsmap that the model is expecting as input</span>
+<span class="sd">                     (the one that is being cleaned).</span>
+<span class="sd">            features: the features of the pulsemap that the model is expecting.</span>
+<span class="sd">            pulsemap_extractor: The extractor used to extract the pulsemap.</span>
+<span class="sd">            model_config: The ModelConfig (or path to it) that summarizes the</span>
+<span class="sd">                            model used for inference.</span>
+<span class="sd">            state_dict: Path to state_dict containing the learned weights.</span>
+<span class="sd">            model_name: The name used for the model. Will help define the named</span>
+<span class="sd">                        entry in the I3Frame. E.g. "dynedge".</span>
+<span class="sd">            gcd_file: path to associated gcd file.</span>
+<span class="sd">            threshold: the threshold for being considered a positive case.</span>
+<span class="sd">                        E.g., predictions &gt;= threshold will be considered</span>
+<span class="sd">                        to be signal, all else noise.</span>
+<span class="sd">            discard_empty_events: When true, this flag will eliminate events</span>
+<span class="sd">                            whose cleaned pulse series are empty. Can be used</span>
+<span class="sd">                            to speed up processing especially for noise</span>
+<span class="sd">                            simulation, since it will not do any writing or</span>
+<span class="sd">                            further calculations.</span>
+<span class="sd">            prediction_columns: column names for the predictions of the model.</span>
+<span class="sd">                            Will help define the named entry in the I3Frame.</span>
+<span class="sd">                            E.g. ['energy_reco']. Optional.</span>
+<span class="sd">        """</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">pulsemap</span><span class="o">=</span><span class="n">pulsemap</span><span class="p">,</span>
+            <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span>
+            <span class="n">pulsemap_extractor</span><span class="o">=</span><span class="n">pulsemap_extractor</span><span class="p">,</span>
+            <span class="n">model_config</span><span class="o">=</span><span class="n">model_config</span><span class="p">,</span>
+            <span class="n">state_dict</span><span class="o">=</span><span class="n">state_dict</span><span class="p">,</span>
+            <span class="n">model_name</span><span class="o">=</span><span class="n">model_name</span><span class="p">,</span>
+            <span class="n">prediction_columns</span><span class="o">=</span><span class="n">prediction_columns</span><span class="p">,</span>
+            <span class="n">gcd_file</span><span class="o">=</span><span class="n">gcd_file</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span> <span class="o">=</span> <span class="n">threshold</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_predictions_key</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">pulsemap</span><span class="si">}</span><span class="s2">_</span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2">_Predictions"</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">pulsemap</span><span class="si">}</span><span class="s2">_</span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2">_Pulses"</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_discard_empty_events</span> <span class="o">=</span> <span class="n">discard_empty_events</span>
+
+    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Add a cleaned pulsemap to frame."""</span>
+        <span class="c1"># inference</span>
+        <span class="n">gcd_file</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_gcd_file</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_graph</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">graph</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>  <span class="c1"># If there is no pulses to clean</span>
+            <span class="k">return</span> <span class="kc">False</span>
+        <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_discard_empty_events</span><span class="p">:</span>
+            <span class="k">if</span> <span class="nb">sum</span><span class="p">(</span><span class="n">predictions</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
+                <span class="k">return</span> <span class="kc">False</span>
+
+        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
+            <span class="n">predictions</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+
+        <span class="k">assert</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span>
+
+        <span class="c1"># Build Dictionary of predictions</span>
+        <span class="n">data</span> <span class="o">=</span> <span class="p">{}</span>
+
+        <span class="n">predictions_map</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_prediction_map</span><span class="p">(</span>
+            <span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">predictions</span><span class="o">=</span><span class="n">predictions</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Adds the raw predictions to dictionary</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_predictions_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+            <span class="n">data</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_predictions_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">predictions_map</span>
+
+        <span class="c1"># Create a pulse map mask, indicating the pulses that are over</span>
+        <span class="c1"># threshold (e.g. identified as signal) and therefore should be kept</span>
+        <span class="c1"># Using a lambda function to evaluate which pulses to keep by</span>
+        <span class="c1"># checking the prediction for each pulse</span>
+        <span class="c1"># (Adds the actual pulsemap to dictionary)</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+            <span class="n">data</span><span class="p">[</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span>
+            <span class="p">]</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMapMask</span><span class="p">(</span>
+                <span class="n">frame</span><span class="p">,</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemap</span><span class="p">,</span>
+                <span class="k">lambda</span> <span class="n">om_key</span><span class="p">,</span> <span class="n">index</span><span class="p">,</span> <span class="n">pulse</span><span class="p">:</span> <span class="n">predictions_map</span><span class="p">[</span><span class="n">om_key</span><span class="p">][</span><span class="n">index</span><span class="p">]</span>
+                <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span><span class="p">,</span>
+            <span class="p">)</span>
+
+        <span class="c1"># Submit predictions and general pulsemap</span>
+        <span class="n">frame</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_to_frame</span><span class="p">(</span><span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">)</span>
+        <span class="n">data</span> <span class="o">=</span> <span class="p">{}</span>
+        <span class="c1"># Adds an additional pulsemap for each DOM type</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">I3FeatureExtractorIceCubeUpgrade</span>
+        <span class="p">):</span>
+            <span class="n">mDOMMap</span><span class="p">,</span> <span class="n">DEggMap</span><span class="p">,</span> <span class="n">IceCubeMap</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_split_pulsemap_in_dom_types</span><span class="p">(</span>
+                <span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">gcd_file</span><span class="o">=</span><span class="n">gcd_file</span>
+            <span class="p">)</span>
+
+            <span class="k">if</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_mDOMs_Only"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+                <span class="n">data</span><span class="p">[</span>
+                    <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_mDOMs_Only"</span>
+                <span class="p">]</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="p">(</span><span class="n">mDOMMap</span><span class="p">)</span>
+
+            <span class="k">if</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_dEggs_Only"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+                <span class="n">data</span><span class="p">[</span>
+                    <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_dEggs_Only"</span>
+                <span class="p">]</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="p">(</span><span class="n">DEggMap</span><span class="p">)</span>
+
+            <span class="k">if</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_pDOMs_Only"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+                <span class="n">data</span><span class="p">[</span>
+                    <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_pDOMs_Only"</span>
+                <span class="p">]</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="p">(</span><span class="n">IceCubeMap</span><span class="p">)</span>
+
+        <span class="c1"># Submits the additional pulsemaps to the frame</span>
+        <span class="n">frame</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_to_frame</span><span class="p">(</span><span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="kc">True</span>
+
+    <span class="k">def</span> <span class="nf">_split_pulsemap_in_dom_types</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">,</span> <span class="n">gcd_file</span><span class="p">:</span> <span class="n">Any</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Will split the cleaned pulsemap into multiple pulsemaps.</span>
+
+<span class="sd">        Arguments:</span>
+<span class="sd">            frame: I3Frame (physics)</span>
+<span class="sd">            gcd_file: path to associated gcd file</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            mDOMMap, DeGGMap, IceCubeMap</span>
+<span class="sd">        """</span>
+        <span class="n">g</span> <span class="o">=</span> <span class="n">dataio</span><span class="o">.</span><span class="n">I3File</span><span class="p">(</span><span class="n">gcd_file</span><span class="p">)</span>
+        <span class="n">gFrame</span> <span class="o">=</span> <span class="n">g</span><span class="o">.</span><span class="n">pop_frame</span><span class="p">()</span>
+        <span class="k">while</span> <span class="s2">"I3Geometry"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">gFrame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+            <span class="n">gFrame</span> <span class="o">=</span> <span class="n">g</span><span class="o">.</span><span class="n">pop_frame</span><span class="p">()</span>
+        <span class="n">omGeoMap</span> <span class="o">=</span> <span class="n">gFrame</span><span class="p">[</span><span class="s2">"I3Geometry"</span><span class="p">]</span><span class="o">.</span><span class="n">omgeo</span>
+
+        <span class="n">mDOMMap</span><span class="p">,</span> <span class="n">DEggMap</span><span class="p">,</span> <span class="n">IceCubeMap</span> <span class="o">=</span> <span class="p">{},</span> <span class="p">{},</span> <span class="p">{}</span>
+        <span class="n">pulses</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="o">.</span><span class="n">from_frame</span><span class="p">(</span>
+            <span class="n">frame</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span>
+        <span class="p">)</span>
+        <span class="k">for</span> <span class="n">P</span> <span class="ow">in</span> <span class="n">pulses</span><span class="p">:</span>
+            <span class="n">om</span> <span class="o">=</span> <span class="n">omGeoMap</span><span class="p">[</span><span class="n">P</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
+            <span class="k">if</span> <span class="n">om</span><span class="o">.</span><span class="n">omtype</span> <span class="o">==</span> <span class="mi">130</span><span class="p">:</span>  <span class="c1"># "mDOM"</span>
+                <span class="n">mDOMMap</span><span class="p">[</span><span class="n">P</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">P</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
+            <span class="k">elif</span> <span class="n">om</span><span class="o">.</span><span class="n">omtype</span> <span class="o">==</span> <span class="mi">120</span><span class="p">:</span>  <span class="c1"># "DEgg"</span>
+                <span class="n">DEggMap</span><span class="p">[</span><span class="n">P</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">P</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
+            <span class="k">elif</span> <span class="n">om</span><span class="o">.</span><span class="n">omtype</span> <span class="o">==</span> <span class="mi">20</span><span class="p">:</span>  <span class="c1"># "IceCube / pDOM"</span>
+                <span class="n">IceCubeMap</span><span class="p">[</span><span class="n">P</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">P</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
+        <span class="k">return</span> <span class="n">mDOMMap</span><span class="p">,</span> <span class="n">DEggMap</span><span class="p">,</span> <span class="n">IceCubeMap</span>
+
+    <span class="k">def</span> <span class="nf">_construct_prediction_map</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">I3MapKeyVectorDouble</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Make a pulsemap from predictions (for all OM types).</span>
+
+<span class="sd">        Arguments:</span>
+<span class="sd">            frame: I3Frame (physics)</span>
+<span class="sd">            predictions: predictions from GNN</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            predictions_map: a pulsemap from predictions</span>
+<span class="sd">        """</span>
+        <span class="n">pulsemap</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="o">.</span><span class="n">from_frame</span><span class="p">(</span>
+            <span class="n">frame</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemap</span>
+        <span class="p">)</span>
+
+        <span class="n">idx</span> <span class="o">=</span> <span class="mi">0</span>
+        <span class="n">predictions</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">predictions_map</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3MapKeyVectorDouble</span><span class="p">()</span>
+        <span class="k">for</span> <span class="n">om_key</span><span class="p">,</span> <span class="n">pulses</span> <span class="ow">in</span> <span class="n">pulsemap</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
+            <span class="n">num_pulses</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">pulses</span><span class="p">)</span>
+            <span class="n">predictions_map</span><span class="p">[</span><span class="n">om_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">predictions</span><span class="p">[</span>
+                <span class="n">idx</span> <span class="p">:</span> <span class="n">idx</span> <span class="o">+</span> <span class="n">num_pulses</span>
+            <span class="p">]</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+            <span class="n">idx</span> <span class="o">+=</span> <span class="n">num_pulses</span>
+
+        <span class="c1"># Checks</span>
+        <span class="k">assert</span> <span class="n">idx</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
+            <span class="n">predictions</span>
+        <span class="p">),</span> <span class="s2">"""Not all predictions were mapped to pulses,</span><span class="se">\n</span>
+<span class="s2">            validation of predictions have failed."""</span>
+
+        <span class="k">assert</span> <span class="p">(</span>
+            <span class="n">pulsemap</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="o">==</span> <span class="n">predictions_map</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
+        <span class="p">),</span> <span class="s2">"""Input pulse map and predictions map do </span><span class="se">\n</span>
+<span class="s2">              not contain exactly the same OMs"""</span>
+        <span class="k">return</span> <span class="n">predictions_map</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/coarsening.html b/_modules/graphnet/models/coarsening.html
new file mode 100644
index 000000000..6a260f076
--- /dev/null
+++ b/_modules/graphnet/models/coarsening.html
@@ -0,0 +1,711 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.coarsening &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/coarsening" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.coarsening </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-coarsening--page-root">Source code for graphnet.models.coarsening</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Class(es) for coarsening operations (i.e., clustering, or local pooling)."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span>
+<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span>
+<span class="kn">from</span> <span class="nn">sklearn.cluster</span> <span class="kn">import</span> <span class="n">DBSCAN</span>
+
+<span class="c1"># from torch_geometric.utils import unbatch_edge_index</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.components.pool</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">group_by</span><span class="p">,</span>
+    <span class="n">avg_pool</span><span class="p">,</span>
+    <span class="n">max_pool</span><span class="p">,</span>
+    <span class="n">min_pool</span><span class="p">,</span>
+    <span class="n">sum_pool</span><span class="p">,</span>
+    <span class="n">avg_pool_x</span><span class="p">,</span>
+    <span class="n">max_pool_x</span><span class="p">,</span>
+    <span class="n">min_pool_x</span><span class="p">,</span>
+    <span class="n">sum_pool_x</span><span class="p">,</span>
+    <span class="n">std_pool_x</span><span class="p">,</span>
+<span class="p">)</span>
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+
+<span class="c1"># Utility method(s)</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.utils</span> <span class="kn">import</span> <span class="n">degree</span>
+
+<span class="c1"># NOTE: From [https://github.com/pyg-team/pytorch_geometric/pull/4903]</span>
+<span class="c1"># TODO:  Remove once bumping to torch_geometric&gt;=2.1.0</span>
+<span class="c1">#       See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md]</span>
+
+
+<div class="viewcode-block" id="unbatch_edge_index">
+<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.unbatch_edge_index">[docs]</a>
+<span class="k">def</span> <span class="nf">unbatch_edge_index</span><span class="p">(</span><span class="n">edge_index</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
+    <span class="c1"># noqa: D401</span>
+<span class="w">    </span><span class="sa">r</span><span class="sd">"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        edge_index (Tensor): The edge_index tensor. Must be ordered.</span>
+<span class="sd">        batch (LongTensor): The batch vector</span>
+<span class="sd">            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each</span>
+<span class="sd">            node to a specific example. Must be ordered.</span>
+<span class="sd">    :rtype: :class:`List[Tensor]`</span>
+<span class="sd">    """</span>
+    <span class="n">deg</span> <span class="o">=</span> <span class="n">degree</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
+    <span class="n">ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">deg</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">deg</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+
+    <span class="n">edge_batch</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">edge_index</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span>
+    <span class="n">edge_index</span> <span class="o">=</span> <span class="n">edge_index</span> <span class="o">-</span> <span class="n">ptr</span><span class="p">[</span><span class="n">edge_batch</span><span class="p">]</span>
+    <span class="n">sizes</span> <span class="o">=</span> <span class="n">degree</span><span class="p">(</span><span class="n">edge_batch</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+    <span class="k">return</span> <span class="n">edge_index</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="Coarsening">
+<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening">[docs]</a>
+<span class="k">class</span> <span class="nc">Coarsening</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base class for coarsening operations."""</span>
+
+    <span class="c1"># Class variables</span>
+    <span class="n">reduce_options</span> <span class="o">=</span> <span class="p">{</span>
+        <span class="s2">"avg"</span><span class="p">:</span> <span class="p">(</span><span class="n">avg_pool</span><span class="p">,</span> <span class="n">avg_pool_x</span><span class="p">),</span>
+        <span class="s2">"min"</span><span class="p">:</span> <span class="p">(</span><span class="n">min_pool</span><span class="p">,</span> <span class="n">min_pool_x</span><span class="p">),</span>
+        <span class="s2">"max"</span><span class="p">:</span> <span class="p">(</span><span class="n">max_pool</span><span class="p">,</span> <span class="n">max_pool_x</span><span class="p">),</span>
+        <span class="s2">"sum"</span><span class="p">:</span> <span class="p">(</span><span class="n">sum_pool</span><span class="p">,</span> <span class="n">sum_pool_x</span><span class="p">),</span>
+    <span class="p">}</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"avg"</span><span class="p">,</span>
+        <span class="n">transfer_attributes</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `Coarsening`."""</span>
+        <span class="k">assert</span> <span class="n">reduce</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_options</span>
+
+        <span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_method</span><span class="p">,</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_attribute_reduce_method</span><span class="p">,</span>
+        <span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_options</span><span class="p">[</span><span class="n">reduce</span><span class="p">]</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_do_transfer_attributes</span> <span class="o">=</span> <span class="n">transfer_attributes</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_perform_clustering</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Cluster nodes in `data` by assigning a cluster index to each."""</span>
+
+    <span class="k">def</span> <span class="nf">_additional_features</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Batch</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Perform additional poolings of feature tensor `x` on `data`.</span>
+
+<span class="sd">        By default the nominal `pooling_method` is used for features as well.</span>
+<span class="sd">        This method can be overwritten for bespoke coarsening operations.</span>
+<span class="sd">        """</span>
+
+    <span class="k">def</span> <span class="nf">_transfer_attributes</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">original_data</span><span class="p">:</span> <span class="n">Batch</span><span class="p">,</span> <span class="n">pooled_data</span><span class="p">:</span> <span class="n">Batch</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Batch</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Transfer attributes on `original_data` to `pooled_data`."""</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_do_transfer_attributes</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">pooled_data</span>
+
+        <span class="n">attributes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">original_data</span><span class="o">.</span><span class="n">_store</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
+        <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">original_data</span><span class="o">.</span><span class="n">batch</span>
+        <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">attr</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">attributes</span><span class="p">):</span>
+            <span class="k">if</span> <span class="n">attr</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">pooled_data</span><span class="o">.</span><span class="n">_store</span><span class="p">:</span>
+                <span class="n">values</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">original_data</span><span class="p">,</span> <span class="n">attr</span><span class="p">)</span>
+
+                <span class="n">attr_is_node_level_tensor</span> <span class="o">=</span> <span class="kc">False</span>
+                <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">values</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
+                    <span class="k">if</span> <span class="n">batch</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                        <span class="n">attr_is_node_level_tensor</span> <span class="o">=</span> <span class="p">(</span>
+                            <span class="n">values</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">values</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span>
+                        <span class="p">)</span>
+                    <span class="k">else</span><span class="p">:</span>
+                        <span class="n">attr_is_node_level_tensor</span> <span class="o">=</span> <span class="p">(</span>
+                            <span class="n">values</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="n">original_data</span><span class="o">.</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
+                        <span class="p">)</span>
+
+                <span class="k">if</span> <span class="n">attr_is_node_level_tensor</span><span class="p">:</span>
+                    <span class="n">values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_attribute_reduce_method</span><span class="p">(</span>
+                        <span class="n">cluster</span><span class="p">,</span>
+                        <span class="n">values</span><span class="p">,</span>
+                        <span class="n">batch</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">values</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
+                    <span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
+
+                <span class="nb">setattr</span><span class="p">(</span><span class="n">pooled_data</span><span class="p">,</span> <span class="n">attr</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">pooled_data</span>
+
+<div class="viewcode-block" id="Coarsening.forward">
+<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Perform coarsening operation."""</span>
+        <span class="c1"># Get tensor of cluster indices for each node.</span>
+        <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perform_clustering</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
+
+        <span class="c1"># Check whether a graph has already been built. Otherwise, set a dummy</span>
+        <span class="c1"># connectivity, as this is required by pooling functions.</span>
+        <span class="n">edge_index</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span>
+        <span class="k">if</span> <span class="n">edge_index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
+
+        <span class="c1"># Pool `data` object, including `x`, `batch`. and `edge_index`.</span>
+        <span class="n">pooled_data</span><span class="p">:</span> <span class="n">Batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_method</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span>
+
+        <span class="c1"># Optionally overwrite feature tensor</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_additional_features</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">x</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">pooled_data</span><span class="o">.</span><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+                <span class="p">(</span>
+                    <span class="n">pooled_data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span>
+                    <span class="n">x</span><span class="p">,</span>
+                <span class="p">),</span>
+                <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+            <span class="p">)</span>
+
+        <span class="c1"># Reset `edge_index` if necessary.</span>
+        <span class="k">if</span> <span class="n">edge_index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">edge_index</span>
+            <span class="n">pooled_data</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">edge_index</span>
+
+        <span class="c1"># Transfer attributes on `data`, pooling as required.</span>
+        <span class="n">pooled_data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transfer_attributes</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">pooled_data</span><span class="p">)</span>
+
+        <span class="c1"># Reconstruct Batch Attributes</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">):</span>  <span class="c1"># if a Batch object</span>
+            <span class="n">pooled_data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reconstruct_batch</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">pooled_data</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">pooled_data</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_reconstruct_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">original</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">pooled</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+        <span class="n">pooled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_slice_dict</span><span class="p">(</span><span class="n">original</span><span class="p">,</span> <span class="n">pooled</span><span class="p">)</span>
+        <span class="n">pooled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_inc_dict</span><span class="p">(</span><span class="n">original</span><span class="p">,</span> <span class="n">pooled</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">pooled</span>
+
+    <span class="k">def</span> <span class="nf">_add_slice_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">original</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">pooled</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+        <span class="c1"># Copy original slice_dict and count nodes in each graph in pooled batch</span>
+        <span class="n">slice_dict</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">original</span><span class="o">.</span><span class="n">_slice_dict</span><span class="p">)</span>
+        <span class="n">_</span><span class="p">,</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique_consecutive</span><span class="p">(</span><span class="n">pooled</span><span class="o">.</span><span class="n">batch</span><span class="p">,</span> <span class="n">return_counts</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="c1"># Reconstruct the entry in slice_dict for pulsemaps - only these are affected by pooling</span>
+        <span class="n">pulsemap_slice</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">counts</span><span class="p">)):</span>
+            <span class="n">pulsemap_slice</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pulsemap_slice</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">counts</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
+
+        <span class="c1"># Identifies pulsemap entries in slice_dict and set them to pulsemap_slice</span>
+        <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="n">slice_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+            <span class="k">if</span> <span class="p">(</span><span class="n">original</span><span class="o">.</span><span class="n">_num_graphs</span><span class="p">)</span> <span class="o">==</span> <span class="n">slice_dict</span><span class="p">[</span><span class="n">field</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
+                <span class="k">pass</span>  <span class="c1"># not pulsemap, so skip</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="n">slice_dict</span><span class="p">[</span><span class="n">field</span><span class="p">]</span> <span class="o">=</span> <span class="n">pulsemap_slice</span>
+        <span class="n">pooled</span><span class="o">.</span><span class="n">_slice_dict</span> <span class="o">=</span> <span class="n">slice_dict</span>
+        <span class="k">return</span> <span class="n">pooled</span>
+
+    <span class="k">def</span> <span class="nf">_add_inc_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">original</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">pooled</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+        <span class="c1"># not changed by coarsening</span>
+        <span class="n">pooled</span><span class="o">.</span><span class="n">_inc_dict</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">original</span><span class="o">.</span><span class="n">_inc_dict</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">pooled</span></div>
+
+
+
+<div class="viewcode-block" id="AttributeCoarsening">
+<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.AttributeCoarsening">[docs]</a>
+<span class="k">class</span> <span class="nc">AttributeCoarsening</span><span class="p">(</span><span class="n">Coarsening</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Coarsen pulses based on specified attributes."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">attributes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"avg"</span><span class="p">,</span>
+        <span class="n">transfer_attributes</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `SimpleCoarsening`."""</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_attributes</span> <span class="o">=</span> <span class="n">attributes</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">transfer_attributes</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_perform_clustering</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Cluster nodes in `data` by assigning a cluster index to each."""</span>
+        <span class="n">dom_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_attributes</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">dom_index</span></div>
+
+
+
+<div class="viewcode-block" id="DOMCoarsening">
+<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.DOMCoarsening">[docs]</a>
+<span class="k">class</span> <span class="nc">DOMCoarsening</span><span class="p">(</span><span class="n">Coarsening</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Coarsen pulses to DOM-level."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"avg"</span><span class="p">,</span>
+        <span class="n">transfer_attributes</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+        <span class="n">keys</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Cluster pulses on the same DOM."""</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">transfer_attributes</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">keys</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span> <span class="o">=</span> <span class="p">[</span>
+                <span class="s2">"dom_x"</span><span class="p">,</span>
+                <span class="s2">"dom_y"</span><span class="p">,</span>
+                <span class="s2">"dom_z"</span><span class="p">,</span>
+                <span class="s2">"rde"</span><span class="p">,</span>
+                <span class="s2">"pmt_area"</span><span class="p">,</span>
+            <span class="p">]</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span> <span class="o">=</span> <span class="n">keys</span>
+
+    <span class="k">def</span> <span class="nf">_perform_clustering</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Cluster nodes in `data` by assigning a cluster index to each."""</span>
+        <span class="n">dom_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">dom_index</span></div>
+
+
+
+<div class="viewcode-block" id="CustomDOMCoarsening">
+<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.CustomDOMCoarsening">[docs]</a>
+<span class="k">class</span> <span class="nc">CustomDOMCoarsening</span><span class="p">(</span><span class="n">DOMCoarsening</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Coarsen pulses to DOM-level with additional attributes."""</span>
+
+    <span class="k">def</span> <span class="nf">_additional_features</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Perform Additional poolings of feature tensor `x` on `data`."""</span>
+        <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span>
+
+        <span class="n">features</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">features</span>
+        <span class="k">if</span> <span class="n">batch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">features</span> <span class="o">=</span> <span class="p">[</span><span class="n">feats</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">feats</span> <span class="ow">in</span> <span class="n">features</span><span class="p">]</span>
+
+        <span class="n">ix_time</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="s2">"dom_time"</span><span class="p">)</span>
+        <span class="n">ix_charge</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="s2">"charge"</span><span class="p">)</span>
+
+        <span class="n">time</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">ix_time</span><span class="p">]</span>
+        <span class="n">charge</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">ix_charge</span><span class="p">]</span>
+
+        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span>
+            <span class="p">(</span>
+                <span class="n">min_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">time</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
+                <span class="n">max_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">time</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
+                <span class="n">std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">time</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
+                <span class="n">min_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">charge</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
+                <span class="n">max_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">charge</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
+                <span class="n">std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">charge</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
+                <span class="n">sum_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">charge</span><span class="p">),</span> <span class="n">batch</span><span class="p">)[</span>
+                    <span class="mi">0</span>
+                <span class="p">],</span>  <span class="c1"># Num. nodes (pulses) per cluster (DOM)</span>
+            <span class="p">),</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">x</span></div>
+
+
+
+<div class="viewcode-block" id="DOMAndTimeWindowCoarsening">
+<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.DOMAndTimeWindowCoarsening">[docs]</a>
+<span class="k">class</span> <span class="nc">DOMAndTimeWindowCoarsening</span><span class="p">(</span><span class="n">Coarsening</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Coarsen pulses to DOM-level, with additional time-window clustering."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">time_window</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
+        <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"avg"</span><span class="p">,</span>
+        <span class="n">transfer_attributes</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+        <span class="n">keys</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="s2">"dom_x"</span><span class="p">,</span>
+            <span class="s2">"dom_y"</span><span class="p">,</span>
+            <span class="s2">"dom_z"</span><span class="p">,</span>
+            <span class="s2">"rde"</span><span class="p">,</span>
+            <span class="s2">"pmt_area"</span><span class="p">,</span>
+        <span class="p">],</span>
+        <span class="n">time_key</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"dom_time"</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Cluster pulses on the same DOM within `time_window`."""</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">transfer_attributes</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_time_window</span> <span class="o">=</span> <span class="n">time_window</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_cluster_method</span> <span class="o">=</span> <span class="n">DBSCAN</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_time_window</span><span class="p">,</span> <span class="n">min_samples</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span> <span class="o">=</span> <span class="n">keys</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_time_key</span> <span class="o">=</span> <span class="n">time_key</span>
+
+    <span class="k">def</span> <span class="nf">_perform_clustering</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Cluster nodes in `data` by assigning a cluster index to each."""</span>
+        <span class="n">dom_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">features</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">features</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">features</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">features</span>
+
+        <span class="n">ix_time</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_time_key</span><span class="p">)</span>
+        <span class="n">hit_times</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">ix_time</span><span class="p">]</span>
+
+        <span class="c1"># Scale up dom_index to make sure clusters are well separated</span>
+        <span class="n">times_and_domids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">hit_times</span><span class="p">,</span>
+                <span class="n">dom_index</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">_time_window</span> <span class="o">*</span> <span class="mi">10</span><span class="p">,</span>
+            <span class="p">]</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">T</span>
+        <span class="n">clusters</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_cluster_method</span><span class="o">.</span><span class="n">fit_predict</span><span class="p">(</span><span class="n">times_and_domids</span><span class="o">.</span><span class="n">cpu</span><span class="p">()),</span>
+            <span class="n">device</span><span class="o">=</span><span class="n">hit_times</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">clusters</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/components/layers.html b/_modules/graphnet/models/components/layers.html
new file mode 100644
index 000000000..555ac857b
--- /dev/null
+++ b/_modules/graphnet/models/components/layers.html
@@ -0,0 +1,579 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.components.layers &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/components/layers" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.components.layers </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-components-layers--page-root">Source code for graphnet.models.components.layers</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Class(es) implementing layers to be used in `graphnet` models."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch.functional</span> <span class="kn">import</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn</span> <span class="kn">import</span> <span class="n">EdgeConv</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn.pool</span> <span class="kn">import</span> <span class="n">knn_graph</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.typing</span> <span class="kn">import</span> <span class="n">Adj</span><span class="p">,</span> <span class="n">PairTensor</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn.conv</span> <span class="kn">import</span> <span class="n">MessagePassing</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn.inits</span> <span class="kn">import</span> <span class="n">reset</span>
+<span class="kn">from</span> <span class="nn">torch.nn.modules</span> <span class="kn">import</span> <span class="n">TransformerEncoder</span><span class="p">,</span> <span class="n">TransformerEncoderLayer</span>
+<span class="kn">from</span> <span class="nn">torch.nn.modules.normalization</span> <span class="kn">import</span> <span class="n">LayerNorm</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.utils</span> <span class="kn">import</span> <span class="n">to_dense_batch</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">LightningModule</span>
+
+
+<div class="viewcode-block" id="DynEdgeConv">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv">[docs]</a>
+<span class="k">class</span> <span class="nc">DynEdgeConv</span><span class="p">(</span><span class="n">EdgeConv</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Dynamical edge convolution layer."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">nn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span>
+        <span class="n">aggr</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"max"</span><span class="p">,</span>
+        <span class="n">nb_neighbors</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
+        <span class="n">features_subset</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">slice</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `DynEdgeConv`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            nn: The MLP/torch.Module to be used within the `EdgeConv`.</span>
+<span class="sd">            aggr: Aggregation method to be used with `EdgeConv`.</span>
+<span class="sd">            nb_neighbors: Number of neighbours to be clustered after the</span>
+<span class="sd">                `EdgeConv` operation.</span>
+<span class="sd">            features_subset: Subset of features in `Data.x` that should be used</span>
+<span class="sd">                when dynamically performing the new graph clustering after the</span>
+<span class="sd">                `EdgeConv` operation. Defaults to all features.</span>
+<span class="sd">            **kwargs: Additional features to be passed to `EdgeConv`.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="n">features_subset</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">features_subset</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>  <span class="c1"># Use all features</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">features_subset</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">slice</span><span class="p">))</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nn</span><span class="o">=</span><span class="n">nn</span><span class="p">,</span> <span class="n">aggr</span><span class="o">=</span><span class="n">aggr</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+        <span class="c1"># Additional member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nb_neighbors</span> <span class="o">=</span> <span class="n">nb_neighbors</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">features_subset</span> <span class="o">=</span> <span class="n">features_subset</span>
+
+<div class="viewcode-block" id="DynEdgeConv.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">:</span> <span class="n">Adj</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Forward pass."""</span>
+        <span class="c1"># Standard EdgeConv forward pass</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">)</span>
+
+        <span class="c1"># Recompute adjacency</span>
+        <span class="n">edge_index</span> <span class="o">=</span> <span class="n">knn_graph</span><span class="p">(</span>
+            <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">features_subset</span><span class="p">],</span>
+            <span class="n">k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_neighbors</span><span class="p">,</span>
+            <span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span></div>
+</div>
+
+
+
+<div class="viewcode-block" id="EdgeConvTito">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito">[docs]</a>
+<span class="k">class</span> <span class="nc">EdgeConvTito</span><span class="p">(</span><span class="n">MessagePassing</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Implementation of EdgeConvTito layer used in TITO solution for.</span>
+
+<span class="sd">    'IceCube - Neutrinos in Deep' kaggle competition.</span>
+<span class="sd">    """</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">nn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span>
+        <span class="n">aggr</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"max"</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `EdgeConvTito`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            nn: The MLP/torch.Module to be used within the `EdgeConvTito`.</span>
+<span class="sd">            aggr: Aggregation method to be used with `EdgeConvTito`.</span>
+<span class="sd">            **kwargs: Additional features to be passed to `EdgeConvTito`.</span>
+<span class="sd">        """</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">aggr</span><span class="o">=</span><span class="n">aggr</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nn</span> <span class="o">=</span> <span class="n">nn</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">reset_parameters</span><span class="p">()</span>
+
+<div class="viewcode-block" id="EdgeConvTito.reset_parameters">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.reset_parameters">[docs]</a>
+    <span class="k">def</span> <span class="nf">reset_parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Reset all learnable parameters of the module."""</span>
+        <span class="n">reset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nn</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="EdgeConvTito.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">PairTensor</span><span class="p">],</span> <span class="n">edge_index</span><span class="p">:</span> <span class="n">Adj</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Forward pass."""</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
+            <span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
+        <span class="c1"># propagate_type: (x: PairTensor)</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">propagate</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="EdgeConvTito.message">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.message">[docs]</a>
+    <span class="k">def</span> <span class="nf">message</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_i</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">x_j</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Edgeconvtito message passing."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">nn</span><span class="p">(</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x_i</span><span class="p">,</span> <span class="n">x_j</span> <span class="o">-</span> <span class="n">x_i</span><span class="p">,</span> <span class="n">x_j</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
+        <span class="p">)</span>  <span class="c1"># EdgeConvTito</span></div>
+
+
+    <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Print out module name."""</span>
+        <span class="k">return</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">(nn=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">nn</span><span class="si">}</span><span class="s2">)"</span></div>
+
+
+
+<div class="viewcode-block" id="DynTrans">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans">[docs]</a>
+<span class="k">class</span> <span class="nc">DynTrans</span><span class="p">(</span><span class="n">EdgeConvTito</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Implementation of dynTrans1 layer used in TITO solution for.</span>
+
+<span class="sd">    'IceCube - Neutrinos in Deep' kaggle competition.</span>
+<span class="sd">    """</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">aggr</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"max"</span><span class="p">,</span>
+        <span class="n">features_subset</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">slice</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">n_head</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `DynTrans`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            nn: The MLP/torch.Module to be used within the `DynTrans`.</span>
+<span class="sd">            layer_sizes: List of layer sizes to be used in `DynTrans`.</span>
+<span class="sd">            aggr: Aggregation method to be used with `DynTrans`.</span>
+<span class="sd">            features_subset: Subset of features in `Data.x` that should be used</span>
+<span class="sd">                when dynamically performing the new graph clustering after the</span>
+<span class="sd">                `EdgeConv` operation. Defaults to all features.</span>
+<span class="sd">            n_head: Number of heads to be used in the multiheadattention models.</span>
+<span class="sd">            **kwargs: Additional features to be passed to `DynTrans`.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="n">features_subset</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">features_subset</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>  <span class="c1"># Use all features</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">features_subset</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">slice</span><span class="p">))</span>
+
+        <span class="k">if</span> <span class="n">layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">]</span>
+        <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span>
+            <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
+        <span class="p">):</span>
+            <span class="k">if</span> <span class="n">ix</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
+                <span class="n">nb_in</span> <span class="o">*=</span> <span class="mi">3</span>  <span class="c1"># edgeConv1</span>
+            <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span>
+            <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">())</span>
+        <span class="n">d_model</span> <span class="o">=</span> <span class="n">nb_out</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nn</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">),</span> <span class="n">aggr</span><span class="o">=</span><span class="n">aggr</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+        <span class="c1"># Additional member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">features_subset</span> <span class="o">=</span> <span class="n">features_subset</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>  <span class="c1"># lNorm</span>
+
+        <span class="c1"># Transformer layer(s)</span>
+        <span class="n">encoder_layer</span> <span class="o">=</span> <span class="n">TransformerEncoderLayer</span><span class="p">(</span>
+            <span class="n">d_model</span><span class="o">=</span><span class="n">d_model</span><span class="p">,</span>
+            <span class="n">nhead</span><span class="o">=</span><span class="n">n_head</span><span class="p">,</span>
+            <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+            <span class="n">norm_first</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_transformer_encoder</span> <span class="o">=</span> <span class="n">TransformerEncoder</span><span class="p">(</span>
+            <span class="n">encoder_layer</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">1</span>
+        <span class="p">)</span>
+
+<div class="viewcode-block" id="DynTrans.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">:</span> <span class="n">Adj</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Forward pass."""</span>
+        <span class="n">x_out</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">)</span>
+
+        <span class="k">if</span> <span class="n">x_out</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
+            <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">x_out</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">x</span> <span class="o">=</span> <span class="n">x_out</span>
+
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># lNorm</span>
+
+        <span class="c1"># Transformer layer</span>
+        <span class="n">x</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">to_dense_batch</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transformer_encoder</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">src_key_padding_mask</span><span class="o">=~</span><span class="n">mask</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span>
+
+        <span class="k">return</span> <span class="n">x</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/components/pool.html b/_modules/graphnet/models/components/pool.html
new file mode 100644
index 000000000..2523dfb69
--- /dev/null
+++ b/_modules/graphnet/models/components/pool.html
@@ -0,0 +1,656 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.components.pool &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/components/pool" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.components.pool </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-components-pool--page-root">Source code for graphnet.models.components.pool</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Functions for performing pooling/clustering/coarsening."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn.pool.consecutive</span> <span class="kn">import</span> <span class="n">consecutive_cluster</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn.pool.pool</span> <span class="kn">import</span> <span class="n">pool_edge</span><span class="p">,</span> <span class="n">pool_batch</span><span class="p">,</span> <span class="n">pool_pos</span>
+<span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter</span><span class="p">,</span> <span class="n">scatter_std</span>
+
+<span class="kn">from</span> <span class="nn">torch_geometric.nn.pool</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">avg_pool</span><span class="p">,</span>
+    <span class="n">max_pool</span><span class="p">,</span>
+    <span class="n">avg_pool_x</span><span class="p">,</span>
+    <span class="n">max_pool_x</span><span class="p">,</span>
+<span class="p">)</span>
+
+
+<div class="viewcode-block" id="min_pool">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool">[docs]</a>
+<span class="k">def</span> <span class="nf">min_pool</span><span class="p">(</span>
+    <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Perform min-pooling of `Data`.</span>
+
+<span class="sd">    Like `max_pool, just negating `data.x`.</span>
+<span class="sd">    """</span>
+    <span class="n">data</span><span class="o">.</span><span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="n">data</span><span class="o">.</span><span class="n">x</span>
+    <span class="n">data_pooled</span> <span class="o">=</span> <span class="n">max_pool</span><span class="p">(</span>
+        <span class="n">cluster</span><span class="p">,</span>
+        <span class="n">data</span><span class="p">,</span>
+        <span class="n">transform</span><span class="p">,</span>
+    <span class="p">)</span>
+    <span class="n">data</span><span class="o">.</span><span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="n">data</span><span class="o">.</span><span class="n">x</span>
+    <span class="n">data_pooled</span><span class="o">.</span><span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="n">data_pooled</span><span class="o">.</span><span class="n">x</span>
+    <span class="k">return</span> <span class="n">data_pooled</span></div>
+
+
+
+<div class="viewcode-block" id="min_pool_x">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool_x">[docs]</a>
+<span class="k">def</span> <span class="nf">min_pool_x</span><span class="p">(</span>
+    <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+    <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+    <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+    <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Perform min-pooling of `Tensor`.</span>
+
+<span class="sd">    Like `max_pool_x, just negating `x`.</span>
+<span class="sd">    """</span>
+    <span class="n">ret</span> <span class="o">=</span> <span class="n">max_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="o">-</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">size</span><span class="p">)</span>
+    <span class="k">if</span> <span class="n">size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="o">-</span><span class="n">ret</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ret</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
+    <span class="k">else</span><span class="p">:</span>
+        <span class="k">return</span> <span class="o">-</span><span class="n">ret</span></div>
+
+
+
+<div class="viewcode-block" id="sum_pool_and_distribute">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_and_distribute">[docs]</a>
+<span class="k">def</span> <span class="nf">sum_pool_and_distribute</span><span class="p">(</span>
+    <span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+    <span class="n">cluster_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+    <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Sum-pool values and distribute result to the individual nodes."""</span>
+    <span class="k">if</span> <span class="n">batch</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">batch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span><span class="o">.</span><span class="n">long</span><span class="p">()</span>
+    <span class="n">tensor_pooled</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">sum_pool_x</span><span class="p">(</span><span class="n">cluster_index</span><span class="p">,</span> <span class="n">tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+    <span class="n">inv</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster_index</span><span class="p">)</span>
+    <span class="n">tensor_unpooled</span> <span class="o">=</span> <span class="n">tensor_pooled</span><span class="p">[</span><span class="n">inv</span><span class="p">]</span>
+    <span class="k">return</span> <span class="n">tensor_unpooled</span></div>
+
+
+
+<span class="k">def</span> <span class="nf">_group_identical</span><span class="p">(</span>
+    <span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Group rows in `tensor` that are identical.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        tensor: Tensor of shape [N, F].</span>
+<span class="sd">        batch: Batch indices, to only group identical rows within batches.</span>
+
+<span class="sd">    Returns:</span>
+<span class="sd">        List of group indices, from 0 to num. groups - 1, assigning all</span>
+<span class="sd">            identical rows to the same group.</span>
+<span class="sd">    """</span>
+    <span class="k">if</span> <span class="n">batch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">batch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">tensor</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">return_inverse</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="nb">sorted</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span>
+
+
+<div class="viewcode-block" id="group_by">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_by">[docs]</a>
+<span class="k">def</span> <span class="nf">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">],</span> <span class="n">keys</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">LongTensor</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Group nodes in `data` that have identical values of `keys`.</span>
+
+<span class="sd">    This grouping is done with in each event in case of batching. This allows</span>
+<span class="sd">    for, e.g., assigning the same index to all pulses on the same PMT or DOM in</span>
+<span class="sd">    the same event. This can be used for coarsening graphs, e.g., from pulse-</span>
+<span class="sd">    level to DOM-level by aggregating feature across each group returned by this</span>
+<span class="sd">    method.</span>
+
+<span class="sd">    Example:</span>
+<span class="sd">      Given:</span>
+<span class="sd">        data.f1 = [1,1,2,2,2]</span>
+<span class="sd">        data.f2 = [6,7,7,7,8]</span>
+<span class="sd">      Calls:</span>
+<span class="sd">        groupby(data, ['f1'])       -&gt; [0, 0, 1, 1, 1]</span>
+<span class="sd">        groupby(data, ['f2'])       -&gt; [0, 1, 1, 1, 2]</span>
+<span class="sd">        groupby(data, ['f1', 'f2']) -&gt; [0, 1, 2, 2, 3]</span>
+<span class="sd">    """</span>
+    <span class="n">features</span> <span class="o">=</span> <span class="p">[</span><span class="nb">getattr</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">keys</span><span class="p">]</span>
+    <span class="n">tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">features</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>  <span class="c1"># .int()  @TODO: Required? Use rounding?</span>
+    <span class="n">batch</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="s2">"batch"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
+    <span class="n">index</span> <span class="o">=</span> <span class="n">_group_identical</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+    <span class="k">return</span> <span class="n">index</span></div>
+
+
+
+<div class="viewcode-block" id="group_pulses_to_dom">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_dom">[docs]</a>
+<span class="k">def</span> <span class="nf">group_pulses_to_dom</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Group pulses on the same DOM, using DOM and string number."""</span>
+    <span class="n">data</span><span class="o">.</span><span class="n">dom_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">[</span><span class="s2">"dom_number"</span><span class="p">,</span> <span class="s2">"string"</span><span class="p">])</span>
+    <span class="k">return</span> <span class="n">data</span></div>
+
+
+
+<div class="viewcode-block" id="group_pulses_to_pmt">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_pmt">[docs]</a>
+<span class="k">def</span> <span class="nf">group_pulses_to_pmt</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Group pulses on the same PMT, using PMT, DOM, and string number."""</span>
+    <span class="n">data</span><span class="o">.</span><span class="n">pmt_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">[</span><span class="s2">"pmt_number"</span><span class="p">,</span> <span class="s2">"dom_number"</span><span class="p">,</span> <span class="s2">"string"</span><span class="p">])</span>
+    <span class="k">return</span> <span class="n">data</span></div>
+
+
+
+<span class="c1"># Below mirroring `torch_geometric.nn.pool.{avg,max}_pool.py`.</span>
+<span class="k">def</span> <span class="nf">_sum_pool_x</span><span class="p">(</span>
+    <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+    <span class="k">return</span> <span class="n">scatter</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">cluster</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim_size</span><span class="o">=</span><span class="n">size</span><span class="p">,</span> <span class="n">reduce</span><span class="o">=</span><span class="s2">"sum"</span><span class="p">)</span>
+
+
+<span class="k">def</span> <span class="nf">_std_pool_x</span><span class="p">(</span>
+    <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+    <span class="k">return</span> <span class="n">scatter_std</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">cluster</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim_size</span><span class="o">=</span><span class="n">size</span><span class="p">,</span> <span class="n">unbiased</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
+
+
+<div class="viewcode-block" id="sum_pool_x">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_x">[docs]</a>
+<span class="k">def</span> <span class="nf">sum_pool_x</span><span class="p">(</span>
+    <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+    <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+    <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+    <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">    </span><span class="sa">r</span><span class="sd">"""Sum-pool node features according to the clustering defined in `cluster`.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,</span>
+<span class="sd">            \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.</span>
+<span class="sd">        x: Node feature matrix</span>
+<span class="sd">            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.</span>
+<span class="sd">        batch: Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,</span>
+<span class="sd">            B-1\}}^N`, which assigns each node to a specific example.</span>
+<span class="sd">        size: The maximum number of clusters in a single</span>
+<span class="sd">            example. This property is useful to obtain a batch-wise dense</span>
+<span class="sd">            representation, *e.g.* for applying FC layers, but should only be</span>
+<span class="sd">            used if the size of the maximum number of clusters per example is</span>
+<span class="sd">            known in advance.</span>
+<span class="sd">    """</span>
+    <span class="k">if</span> <span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="o">+</span> <span class="mi">1</span>
+        <span class="k">return</span> <span class="n">_sum_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">size</span><span class="p">),</span> <span class="kc">None</span>
+
+    <span class="n">cluster</span><span class="p">,</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster</span><span class="p">)</span>
+    <span class="n">x</span> <span class="o">=</span> <span class="n">_sum_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
+    <span class="n">batch</span> <span class="o">=</span> <span class="n">pool_batch</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+
+    <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">batch</span></div>
+
+
+
+<div class="viewcode-block" id="std_pool_x">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool_x">[docs]</a>
+<span class="k">def</span> <span class="nf">std_pool_x</span><span class="p">(</span>
+    <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+    <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+    <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+    <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">    </span><span class="sa">r</span><span class="sd">"""Std-pool node features according to the clustering defined in `cluster`.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,</span>
+<span class="sd">            \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.</span>
+<span class="sd">        x: Node feature matrix</span>
+<span class="sd">            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.</span>
+<span class="sd">        batch: Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,</span>
+<span class="sd">            B-1\}}^N`, which assigns each node to a specific example.</span>
+<span class="sd">        size: The maximum number of clusters in a single</span>
+<span class="sd">            example. This property is useful to obtain a batch-wise dense</span>
+<span class="sd">            representation, *e.g.* for applying FC layers, but should only be</span>
+<span class="sd">            used if the size of the maximum number of clusters per example is</span>
+<span class="sd">            known in advance.</span>
+<span class="sd">    """</span>
+    <span class="k">if</span> <span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="o">+</span> <span class="mi">1</span>
+        <span class="k">return</span> <span class="n">_std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">size</span><span class="p">),</span> <span class="kc">None</span>
+
+    <span class="n">cluster</span><span class="p">,</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster</span><span class="p">)</span>
+    <span class="n">x</span> <span class="o">=</span> <span class="n">_std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
+    <span class="n">batch</span> <span class="o">=</span> <span class="n">pool_batch</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+
+    <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">batch</span></div>
+
+
+
+<div class="viewcode-block" id="sum_pool">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool">[docs]</a>
+<span class="k">def</span> <span class="nf">sum_pool</span><span class="p">(</span>
+    <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">    </span><span class="sa">r</span><span class="sd">"""Pool and coarsen graph according to the clustering defined in `cluster`.</span>
+
+<span class="sd">    All nodes within the same cluster will be represented as one node.</span>
+<span class="sd">    Final node features are defined by the *sum* of features of all nodes</span>
+<span class="sd">    within the same cluster, node positions are averaged and edge indices are</span>
+<span class="sd">    defined to be the union of the edge indices of all nodes within the same</span>
+<span class="sd">    cluster.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,</span>
+<span class="sd">            \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.</span>
+<span class="sd">        data: Graph data object.</span>
+<span class="sd">        transform: A function/transform that takes in the</span>
+<span class="sd">            coarsened and pooled :obj:`torch_geometric.data.Data` object and</span>
+<span class="sd">            returns a transformed version.</span>
+<span class="sd">    """</span>
+    <span class="n">cluster</span><span class="p">,</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster</span><span class="p">)</span>
+
+    <span class="n">x</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">_sum_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">)</span>
+    <span class="n">index</span><span class="p">,</span> <span class="n">attr</span> <span class="o">=</span> <span class="n">pool_edge</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_attr</span><span class="p">)</span>
+    <span class="n">batch</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">pool_batch</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span><span class="p">)</span>
+    <span class="n">pos</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">pos</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">pool_pos</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">pos</span><span class="p">)</span>
+
+    <span class="n">data</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="o">=</span><span class="n">index</span><span class="p">,</span> <span class="n">edge_attr</span><span class="o">=</span><span class="n">attr</span><span class="p">,</span> <span class="n">pos</span><span class="o">=</span><span class="n">pos</span><span class="p">)</span>
+
+    <span class="k">if</span> <span class="n">transform</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">data</span> <span class="o">=</span> <span class="n">transform</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
+
+    <span class="k">return</span> <span class="n">data</span></div>
+
+
+
+<div class="viewcode-block" id="std_pool">
+<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool">[docs]</a>
+<span class="k">def</span> <span class="nf">std_pool</span><span class="p">(</span>
+    <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">    </span><span class="sa">r</span><span class="sd">"""Pool and coarsen graph according to the clustering defined in `cluster`.</span>
+
+<span class="sd">    All nodes within the same cluster will be represented as one node.</span>
+<span class="sd">    Final node features are defined by the *std* of features of all nodes</span>
+<span class="sd">    within the same cluster, node positions are averaged and edge indices are</span>
+<span class="sd">    defined to be the union of the edge indices of all nodes within the same</span>
+<span class="sd">    cluster.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,</span>
+<span class="sd">            \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.</span>
+<span class="sd">        data: Graph data object.</span>
+<span class="sd">        transform: A function/transform that takes in the</span>
+<span class="sd">            coarsened and pooled :obj:`torch_geometric.data.Data` object and</span>
+<span class="sd">            returns a transformed version.</span>
+<span class="sd">    """</span>
+    <span class="n">cluster</span><span class="p">,</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster</span><span class="p">)</span>
+
+    <span class="n">x</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">_std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">)</span>
+    <span class="n">index</span><span class="p">,</span> <span class="n">attr</span> <span class="o">=</span> <span class="n">pool_edge</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_attr</span><span class="p">)</span>
+    <span class="n">batch</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">pool_batch</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span><span class="p">)</span>
+    <span class="n">pos</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">pos</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">pool_pos</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">pos</span><span class="p">)</span>
+
+    <span class="n">data</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="o">=</span><span class="n">index</span><span class="p">,</span> <span class="n">edge_attr</span><span class="o">=</span><span class="n">attr</span><span class="p">,</span> <span class="n">pos</span><span class="o">=</span><span class="n">pos</span><span class="p">)</span>
+
+    <span class="k">if</span> <span class="n">transform</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">data</span> <span class="o">=</span> <span class="n">transform</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
+
+    <span class="k">return</span> <span class="n">data</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/detector/detector.html b/_modules/graphnet/models/detector/detector.html
new file mode 100644
index 000000000..eef62aefc
--- /dev/null
+++ b/_modules/graphnet/models/detector/detector.html
@@ -0,0 +1,421 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.detector.detector &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/detector/detector" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.detector.detector </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-detector-detector--page-root">Source code for graphnet.models.detector.detector</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Base detector-specific `Model` class(es)."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">List</span>
+
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+
+
+<div class="viewcode-block" id="Detector">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector">[docs]</a>
+<span class="k">class</span> <span class="nc">Detector</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base class for all detector-specific read-ins in graphnet."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `Detector`."""</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+<div class="viewcode-block" id="Detector.feature_map">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector.feature_map">[docs]</a>
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""List of features used/assumed by inheriting `Detector` objects."""</span></div>
+
+
+<div class="viewcode-block" id="Detector.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector.forward">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>  <span class="c1"># type: ignore</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">node_features</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Pre-process graph `Data` features and build graph adjacency."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_standardize</span><span class="p">(</span><span class="n">node_features</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">)</span></div>
+
+
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">_standardize</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">node_features</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+        <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">feature</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">):</span>
+            <span class="k">try</span><span class="p">:</span>
+                <span class="n">node_features</span><span class="p">[:,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature_map</span><span class="p">()[</span><span class="n">feature</span><span class="p">](</span>  <span class="c1"># type: ignore</span>
+                    <span class="n">node_features</span><span class="p">[:,</span> <span class="n">idx</span><span class="p">]</span>
+                <span class="p">)</span>
+            <span class="k">except</span> <span class="ne">KeyError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                    <span class="sa">f</span><span class="s2">"""No Standardization function found for '</span><span class="si">{</span><span class="n">feature</span><span class="si">}</span><span class="s2">'"""</span>
+                <span class="p">)</span>
+                <span class="k">raise</span> <span class="n">e</span>
+        <span class="k">return</span> <span class="n">node_features</span>
+
+    <span class="k">def</span> <span class="nf">_identity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Apply no standardization to input."""</span>
+        <span class="k">return</span> <span class="n">x</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/detector/icecube.html b/_modules/graphnet/models/detector/icecube.html
new file mode 100644
index 000000000..76d9a8e16
--- /dev/null
+++ b/_modules/graphnet/models/detector/icecube.html
@@ -0,0 +1,528 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.detector.icecube &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/detector/icecube" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.detector.icecube </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-detector-icecube--page-root">Source code for graphnet.models.detector.icecube</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""IceCube-specific `Detector` class(es)."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Callable</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models.detector.detector</span> <span class="kn">import</span> <span class="n">Detector</span>
+
+
+<div class="viewcode-block" id="IceCube86">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86">[docs]</a>
+<span class="k">class</span> <span class="nc">IceCube86</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""`Detector` class for IceCube-86."""</span>
+
+<div class="viewcode-block" id="IceCube86.feature_map">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86.feature_map">[docs]</a>
+    <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Map standardization functions to each dimension of input data."""</span>
+        <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">"dom_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span>
+            <span class="s2">"dom_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span>
+            <span class="s2">"dom_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span>
+            <span class="s2">"dom_time"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_time</span><span class="p">,</span>
+            <span class="s2">"charge"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_charge</span><span class="p">,</span>
+            <span class="s2">"rde"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_rde</span><span class="p">,</span>
+            <span class="s2">"pmt_area"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pmt_area</span><span class="p">,</span>
+        <span class="p">}</span>
+        <span class="k">return</span> <span class="n">feature_map</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_dom_xyz</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">500.0</span>
+
+    <span class="k">def</span> <span class="nf">_dom_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">1.0e04</span><span class="p">)</span> <span class="o">/</span> <span class="mf">3.0e4</span>
+
+    <span class="k">def</span> <span class="nf">_charge</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_rde</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">1.25</span><span class="p">)</span> <span class="o">/</span> <span class="mf">0.25</span>
+
+    <span class="k">def</span> <span class="nf">_pmt_area</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">0.05</span></div>
+
+
+
+<div class="viewcode-block" id="IceCubeKaggle">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle">[docs]</a>
+<span class="k">class</span> <span class="nc">IceCubeKaggle</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""`Detector` class for Kaggle Competition."""</span>
+
+<div class="viewcode-block" id="IceCubeKaggle.feature_map">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle.feature_map">[docs]</a>
+    <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Map standardization functions to each dimension of input data."""</span>
+        <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">"x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_xyz</span><span class="p">,</span>
+            <span class="s2">"y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_xyz</span><span class="p">,</span>
+            <span class="s2">"z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_xyz</span><span class="p">,</span>
+            <span class="s2">"time"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_time</span><span class="p">,</span>
+            <span class="s2">"charge"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_charge</span><span class="p">,</span>
+            <span class="s2">"auxiliary"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span>
+        <span class="p">}</span>
+        <span class="k">return</span> <span class="n">feature_map</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_xyz</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">500.0</span>
+
+    <span class="k">def</span> <span class="nf">_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">1.0e04</span><span class="p">)</span> <span class="o">/</span> <span class="mf">3.0e4</span>
+
+    <span class="k">def</span> <span class="nf">_charge</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="mf">3.0</span></div>
+
+
+
+<div class="viewcode-block" id="IceCubeDeepCore">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore">[docs]</a>
+<span class="k">class</span> <span class="nc">IceCubeDeepCore</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""`Detector` class for IceCube-DeepCore."""</span>
+
+<div class="viewcode-block" id="IceCubeDeepCore.feature_map">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map">[docs]</a>
+    <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Map standardization functions to each dimension of input data."""</span>
+        <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">"dom_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xy</span><span class="p">,</span>
+            <span class="s2">"dom_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xy</span><span class="p">,</span>
+            <span class="s2">"dom_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_z</span><span class="p">,</span>
+            <span class="s2">"dom_time"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_time</span><span class="p">,</span>
+            <span class="s2">"charge"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span>
+            <span class="s2">"rde"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_rde</span><span class="p">,</span>
+            <span class="s2">"pmt_area"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pmt_area</span><span class="p">,</span>
+        <span class="p">}</span>
+        <span class="k">return</span> <span class="n">feature_map</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_dom_xy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">100.0</span>
+
+    <span class="k">def</span> <span class="nf">_dom_z</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mf">350.0</span><span class="p">)</span> <span class="o">/</span> <span class="mf">100.0</span>
+
+    <span class="k">def</span> <span class="nf">_dom_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">((</span><span class="n">x</span> <span class="o">/</span> <span class="mf">1.05e04</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="mf">20.0</span>
+
+    <span class="k">def</span> <span class="nf">_rde</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">1.25</span><span class="p">)</span> <span class="o">/</span> <span class="mf">0.25</span>
+
+    <span class="k">def</span> <span class="nf">_pmt_area</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">0.05</span></div>
+
+
+
+<div class="viewcode-block" id="IceCubeUpgrade">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade">[docs]</a>
+<span class="k">class</span> <span class="nc">IceCubeUpgrade</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""`Detector` class for IceCube-Upgrade."""</span>
+
+<div class="viewcode-block" id="IceCubeUpgrade.feature_map">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map">[docs]</a>
+    <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Map standardization functions to each dimension of input data."""</span>
+        <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">"dom_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span>
+            <span class="s2">"dom_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span>
+            <span class="s2">"dom_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span>
+            <span class="s2">"dom_time"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_time</span><span class="p">,</span>
+            <span class="s2">"charge"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_charge</span><span class="p">,</span>
+            <span class="s2">"rde"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span>
+            <span class="s2">"pmt_area"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pmt_area</span><span class="p">,</span>
+            <span class="s2">"string"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_string</span><span class="p">,</span>
+            <span class="s2">"pmt_number"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pmt_number</span><span class="p">,</span>
+            <span class="s2">"dom_number"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_number</span><span class="p">,</span>
+            <span class="s2">"pmt_dir_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span>
+            <span class="s2">"pmt_dir_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span>
+            <span class="s2">"pmt_dir_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span>
+            <span class="s2">"dom_type"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_type</span><span class="p">,</span>
+        <span class="p">}</span>
+
+        <span class="k">return</span> <span class="n">feature_map</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_dom_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="mf">2e04</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.0</span>
+
+    <span class="k">def</span> <span class="nf">_charge</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="mf">2.0</span>
+
+    <span class="k">def</span> <span class="nf">_string</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">50.0</span><span class="p">)</span> <span class="o">/</span> <span class="mf">50.0</span>
+
+    <span class="k">def</span> <span class="nf">_pmt_number</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">20.0</span>
+
+    <span class="k">def</span> <span class="nf">_dom_number</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">60.0</span><span class="p">)</span> <span class="o">/</span> <span class="mf">60.0</span>
+
+    <span class="k">def</span> <span class="nf">_dom_type</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">130.0</span>
+
+    <span class="k">def</span> <span class="nf">_dom_xyz</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">500.0</span>
+
+    <span class="k">def</span> <span class="nf">_pmt_area</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">0.05</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/detector/prometheus.html b/_modules/graphnet/models/detector/prometheus.html
new file mode 100644
index 000000000..7f234af38
--- /dev/null
+++ b/_modules/graphnet/models/detector/prometheus.html
@@ -0,0 +1,395 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.detector.prometheus &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/detector/prometheus" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.detector.prometheus </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-detector-prometheus--page-root">Source code for graphnet.models.detector.prometheus</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Prometheus-specific `Detector` class(es)."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Callable</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models.detector.detector</span> <span class="kn">import</span> <span class="n">Detector</span>
+
+
+<div class="viewcode-block" id="Prometheus">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus">[docs]</a>
+<span class="k">class</span> <span class="nc">Prometheus</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""`Detector` class for Prometheus prototype."""</span>
+
+<div class="viewcode-block" id="Prometheus.feature_map">
+<a class="viewcode-back" href="../../../../api/graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus.feature_map">[docs]</a>
+    <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Map standardization functions to each dimension."""</span>
+        <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">"sensor_pos_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sensor_pos_xy</span><span class="p">,</span>
+            <span class="s2">"sensor_pos_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sensor_pos_xy</span><span class="p">,</span>
+            <span class="s2">"sensor_pos_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sensor_pos_z</span><span class="p">,</span>
+            <span class="s2">"t"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_t</span><span class="p">,</span>
+        <span class="p">}</span>
+        <span class="k">return</span> <span class="n">feature_map</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_sensor_pos_xy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mi">100</span>
+
+    <span class="k">def</span> <span class="nf">_sensor_pos_z</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mi">350</span><span class="p">)</span> <span class="o">/</span> <span class="mi">100</span>
+
+    <span class="k">def</span> <span class="nf">_t</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="p">((</span><span class="n">x</span> <span class="o">/</span> <span class="mf">1.05e04</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="mf">20.0</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/gnn/convnet.html b/_modules/graphnet/models/gnn/convnet.html
new file mode 100644
index 000000000..ada586304
--- /dev/null
+++ b/_modules/graphnet/models/gnn/convnet.html
@@ -0,0 +1,486 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.gnn.convnet &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/gnn/convnet" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.gnn.convnet </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-gnn-convnet--page-root">Source code for graphnet.models.gnn.convnet</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Implementation of the ConvNet GNN model architecture.</span>
+
+<span class="sd">Author: Martin Ha Minh</span>
+<span class="sd">"""</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">BatchNorm1d</span><span class="p">,</span> <span class="n">Linear</span><span class="p">,</span> <span class="n">Dropout</span>
+<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn</span> <span class="kn">import</span> <span class="n">TAGConv</span><span class="p">,</span> <span class="n">global_add_pool</span><span class="p">,</span> <span class="n">global_max_pool</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span>
+
+
+<div class="viewcode-block" id="ConvNet">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet">[docs]</a>
+<span class="k">class</span> <span class="nc">ConvNet</span><span class="p">(</span><span class="n">GNN</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""ConvNet (convolutional network) model."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">nb_outputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">nb_intermediate</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span>
+        <span class="n">dropout_ratio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.3</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `ConvNet`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            nb_inputs: Number of input features, i.e. dimension of input</span>
+<span class="sd">                layer.</span>
+<span class="sd">            nb_outputs: Number of prediction labels, i.e. dimension of</span>
+<span class="sd">                output layer.</span>
+<span class="sd">            nb_intermediate: Number of nodes in intermediate layer(s).</span>
+<span class="sd">            dropout_ratio: Fraction of nodes to drop.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="n">nb_outputs</span><span class="p">)</span>
+
+        <span class="c1"># Member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span> <span class="o">=</span> <span class="n">nb_intermediate</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span> <span class="o">=</span> <span class="mi">6</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span>
+
+        <span class="c1"># Architecture configuration</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">TAGConv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">TAGConv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span> <span class="o">=</span> <span class="n">TAGConv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">batchnorm1</span> <span class="o">=</span> <span class="n">BatchNorm1d</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">linear1</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">linear3</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">linear4</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">linear5</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">drop1</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">drop2</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">drop3</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">drop4</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">drop5</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">out</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_outputs</span><span class="p">)</span>
+
+<div class="viewcode-block" id="ConvNet.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Apply learnable forward pass."""</span>
+        <span class="c1"># Convenience variables</span>
+        <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span>
+
+        <span class="c1"># Graph convolutional operations</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">))</span>
+        <span class="n">x1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">global_add_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span>
+                <span class="n">global_max_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span>
+            <span class="p">],</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">))</span>
+        <span class="n">x2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">global_add_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span>
+                <span class="n">global_max_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span>
+            <span class="p">],</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv3</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">))</span>
+        <span class="n">x3</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">global_add_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span>
+                <span class="n">global_max_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span>
+            <span class="p">],</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Skip-cat</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">x3</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="c1"># Batch-normalising intermediate features</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batchnorm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="c1"># Post-processing</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear3</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear4</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop4</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear5</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop5</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="c1"># Read-out</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">out</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">x</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/gnn/dynedge.html b/_modules/graphnet/models/gnn/dynedge.html
new file mode 100644
index 000000000..d882546a3
--- /dev/null
+++ b/_modules/graphnet/models/gnn/dynedge.html
@@ -0,0 +1,693 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.gnn.dynedge &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/gnn/dynedge" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.gnn.dynedge </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-gnn-dynedge--page-root">Source code for graphnet.models.gnn.dynedge</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Implementation of the DynEdge GNN model architecture."""</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+<span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter_max</span><span class="p">,</span> <span class="n">scatter_mean</span><span class="p">,</span> <span class="n">scatter_min</span><span class="p">,</span> <span class="n">scatter_sum</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models.components.layers</span> <span class="kn">import</span> <span class="n">DynEdgeConv</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.utils</span> <span class="kn">import</span> <span class="n">calculate_xyzt_homophily</span>
+
+<span class="n">GLOBAL_POOLINGS</span> <span class="o">=</span> <span class="p">{</span>
+    <span class="s2">"min"</span><span class="p">:</span> <span class="n">scatter_min</span><span class="p">,</span>
+    <span class="s2">"max"</span><span class="p">:</span> <span class="n">scatter_max</span><span class="p">,</span>
+    <span class="s2">"sum"</span><span class="p">:</span> <span class="n">scatter_sum</span><span class="p">,</span>
+    <span class="s2">"mean"</span><span class="p">:</span> <span class="n">scatter_mean</span><span class="p">,</span>
+<span class="p">}</span>
+
+
+<div class="viewcode-block" id="DynEdge">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge">[docs]</a>
+<span class="k">class</span> <span class="nc">DynEdge</span><span class="p">(</span><span class="n">GNN</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""DynEdge (dynamical edge convolutional) model."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">nb_neighbours</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
+        <span class="n">features_subset</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">slice</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">dynedge_layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">post_processing_layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">readout_layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">global_pooling_schemes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">add_global_variables_after_pooling</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `DynEdge`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            nb_inputs: Number of input features on each node.</span>
+<span class="sd">            nb_neighbours: Number of neighbours to used in the k-nearest</span>
+<span class="sd">                neighbour clustering which is performed after each (dynamical)</span>
+<span class="sd">                edge convolution.</span>
+<span class="sd">            features_subset: The subset of latent features on each node that</span>
+<span class="sd">                are used as metric dimensions when performing the k-nearest</span>
+<span class="sd">                neighbours clustering. Defaults to [0,1,2].</span>
+<span class="sd">            dynedge_layer_sizes: The layer sizes, or latent feature dimenions,</span>
+<span class="sd">                used in the `DynEdgeConv` layer. Each entry in</span>
+<span class="sd">                `dynedge_layer_sizes` corresponds to a single `DynEdgeConv`</span>
+<span class="sd">                layer; the integers in the corresponding tuple corresponds to</span>
+<span class="sd">                the layer sizes in the multi-layer perceptron (MLP) that is</span>
+<span class="sd">                applied within each `DynEdgeConv` layer. That is, a list of</span>
+<span class="sd">                size-two tuples means that all `DynEdgeConv` layers contain a</span>
+<span class="sd">                two-layer MLP.</span>
+<span class="sd">                Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)].</span>
+<span class="sd">            post_processing_layer_sizes: Hidden layer sizes in the MLP</span>
+<span class="sd">                following the skip-concatenation of the outputs of each</span>
+<span class="sd">                `DynEdgeConv` layer. Defaults to [336, 256].</span>
+<span class="sd">            readout_layer_sizes: Hidden layer sizes in the MLP following the</span>
+<span class="sd">                post-processing _and_ optional global pooling. As this is the</span>
+<span class="sd">                last layer(s) in the model, the last layer in the read-out</span>
+<span class="sd">                yields the output of the `DynEdge` model. Defaults to [128,].</span>
+<span class="sd">            global_pooling_schemes: The list global pooling schemes to use.</span>
+<span class="sd">                Options are: "min", "max", "mean", and "sum".</span>
+<span class="sd">            add_global_variables_after_pooling: Whether to add global variables</span>
+<span class="sd">                after global pooling. The alternative is to  added (distribute)</span>
+<span class="sd">                them to the individual nodes before any convolutional</span>
+<span class="sd">                operations.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Latent feature subset for computing nearest neighbours in DynEdge.</span>
+        <span class="k">if</span> <span class="n">features_subset</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">features_subset</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
+
+        <span class="c1"># DynEdge layer sizes</span>
+        <span class="k">if</span> <span class="n">dynedge_layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">dynedge_layer_sizes</span> <span class="o">=</span> <span class="p">[</span>
+                <span class="p">(</span>
+                    <span class="mi">128</span><span class="p">,</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                <span class="p">),</span>
+                <span class="p">(</span>
+                    <span class="mi">336</span><span class="p">,</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                <span class="p">),</span>
+                <span class="p">(</span>
+                    <span class="mi">336</span><span class="p">,</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                <span class="p">),</span>
+                <span class="p">(</span>
+                    <span class="mi">336</span><span class="p">,</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                <span class="p">),</span>
+            <span class="p">]</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dynedge_layer_sizes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dynedge_layer_sizes</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dynedge_layer_sizes</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dynedge_layer_sizes</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span>
+            <span class="nb">all</span><span class="p">(</span><span class="n">size</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">sizes</span><span class="p">)</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dynedge_layer_sizes</span>
+        <span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_dynedge_layer_sizes</span> <span class="o">=</span> <span class="n">dynedge_layer_sizes</span>
+
+        <span class="c1"># Post-processing layer sizes</span>
+        <span class="k">if</span> <span class="n">post_processing_layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">post_processing_layer_sizes</span> <span class="o">=</span> <span class="p">[</span>
+                <span class="mi">336</span><span class="p">,</span>
+                <span class="mi">256</span><span class="p">,</span>
+            <span class="p">]</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">post_processing_layer_sizes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">post_processing_layer_sizes</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="n">size</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">post_processing_layer_sizes</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing_layer_sizes</span> <span class="o">=</span> <span class="n">post_processing_layer_sizes</span>
+
+        <span class="c1"># Read-out layer sizes</span>
+        <span class="k">if</span> <span class="n">readout_layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">readout_layer_sizes</span> <span class="o">=</span> <span class="p">[</span>
+                <span class="mi">128</span><span class="p">,</span>
+            <span class="p">]</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">readout_layer_sizes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">readout_layer_sizes</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="n">size</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">readout_layer_sizes</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span> <span class="o">=</span> <span class="n">readout_layer_sizes</span>
+
+        <span class="c1"># Global pooling scheme(s)</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">global_pooling_schemes</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="n">global_pooling_schemes</span> <span class="o">=</span> <span class="p">[</span><span class="n">global_pooling_schemes</span><span class="p">]</span>
+
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">global_pooling_schemes</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="k">for</span> <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="n">global_pooling_schemes</span><span class="p">:</span>
+                <span class="k">assert</span> <span class="p">(</span>
+                    <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="n">GLOBAL_POOLINGS</span>
+                <span class="p">),</span> <span class="sa">f</span><span class="s2">"Global pooling scheme </span><span class="si">{</span><span class="n">pooling_scheme</span><span class="si">}</span><span class="s2"> not supported."</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="n">global_pooling_schemes</span> <span class="ow">is</span> <span class="kc">None</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span> <span class="o">=</span> <span class="n">global_pooling_schemes</span>
+
+        <span class="k">if</span> <span class="n">add_global_variables_after_pooling</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">,</span> <span class="p">(</span>
+                <span class="s2">"No global pooling schemes were request, so cannot add global"</span>
+                <span class="s2">" variables after pooling."</span>
+            <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="n">add_global_variables_after_pooling</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
+
+        <span class="c1"># Remaining member variables()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_activation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> <span class="o">=</span> <span class="n">nb_inputs</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span> <span class="o">=</span> <span class="mi">5</span> <span class="o">+</span> <span class="n">nb_inputs</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_neighbours</span> <span class="o">=</span> <span class="n">nb_neighbours</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_features_subset</span> <span class="o">=</span> <span class="n">features_subset</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_construct_layers</span><span class="p">()</span>
+
+    <span class="k">def</span> <span class="nf">_construct_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct layers (torch.nn.Modules)."""</span>
+        <span class="c1"># Convolutional operations</span>
+        <span class="n">nb_input_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span><span class="p">:</span>
+            <span class="n">nb_input_features</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">()</span>
+        <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">nb_input_features</span>
+        <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dynedge_layer_sizes</span><span class="p">:</span>
+            <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
+            <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">sizes</span><span class="p">)</span>
+            <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span>
+                <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
+            <span class="p">):</span>
+                <span class="k">if</span> <span class="n">ix</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
+                    <span class="n">nb_in</span> <span class="o">*=</span> <span class="mi">2</span>
+                <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span>
+                <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span>
+
+            <span class="n">conv_layer</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">),</span>
+                <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span>
+                <span class="n">nb_neighbors</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_nb_neighbours</span><span class="p">,</span>
+                <span class="n">features_subset</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features_subset</span><span class="p">,</span>
+            <span class="p">)</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">conv_layer</span><span class="p">)</span>
+
+            <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">nb_out</span>
+
+        <span class="c1"># Post-processing operations</span>
+        <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="nb">sum</span><span class="p">(</span><span class="n">sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dynedge_layer_sizes</span><span class="p">)</span>
+            <span class="o">+</span> <span class="n">nb_input_features</span>
+        <span class="p">)</span>
+
+        <span class="n">post_processing_layers</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing_layer_sizes</span>
+        <span class="p">)</span>
+        <span class="k">for</span> <span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
+            <span class="n">post_processing_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span>
+            <span class="n">post_processing_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">post_processing_layers</span><span class="p">)</span>
+
+        <span class="c1"># Read-out operations</span>
+        <span class="n">nb_poolings</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">)</span>
+            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span>
+            <span class="k">else</span> <span class="mi">1</span>
+        <span class="p">)</span>
+        <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">nb_out</span> <span class="o">*</span> <span class="n">nb_poolings</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span><span class="p">:</span>
+            <span class="n">nb_latent_features</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span>
+
+        <span class="n">readout_layers</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span><span class="p">)</span>
+        <span class="k">for</span> <span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
+            <span class="n">readout_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span>
+            <span class="n">readout_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_readout</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">readout_layers</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_global_pooling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Perform global pooling."""</span>
+        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span>
+        <span class="n">pooled</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="k">for</span> <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">:</span>
+            <span class="n">pooling_fn</span> <span class="o">=</span> <span class="n">GLOBAL_POOLINGS</span><span class="p">[</span><span class="n">pooling_scheme</span><span class="p">]</span>
+            <span class="n">pooled_x</span> <span class="o">=</span> <span class="n">pooling_fn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
+                <span class="c1"># `scatter_{min,max}`, which return also an argument, vs.</span>
+                <span class="c1"># `scatter_{mean,sum}`</span>
+                <span class="n">pooled_x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">pooled_x</span>
+            <span class="n">pooled</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">pooled</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_calculate_global_variables</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">edge_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+        <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+        <span class="o">*</span><span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Calculate global variables."""</span>
+        <span class="c1"># Calculate homophily (scalar variables)</span>
+        <span class="n">h_x</span><span class="p">,</span> <span class="n">h_y</span><span class="p">,</span> <span class="n">h_z</span><span class="p">,</span> <span class="n">h_t</span> <span class="o">=</span> <span class="n">calculate_xyzt_homophily</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+
+        <span class="c1"># Calculate mean features</span>
+        <span class="n">global_means</span> <span class="o">=</span> <span class="n">scatter_mean</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+
+        <span class="c1"># Add global variables</span>
+        <span class="n">global_variables</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">global_means</span><span class="p">,</span>
+                <span class="n">h_x</span><span class="p">,</span>
+                <span class="n">h_y</span><span class="p">,</span>
+                <span class="n">h_z</span><span class="p">,</span>
+                <span class="n">h_t</span><span class="p">,</span>
+            <span class="p">]</span>
+            <span class="o">+</span> <span class="p">[</span><span class="n">attr</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">additional_attributes</span><span class="p">],</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">global_variables</span>
+
+<div class="viewcode-block" id="DynEdge.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Apply learnable forward pass."""</span>
+        <span class="c1"># Convenience variables</span>
+        <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span>
+
+        <span class="n">global_variables</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_calculate_global_variables</span><span class="p">(</span>
+            <span class="n">x</span><span class="p">,</span>
+            <span class="n">edge_index</span><span class="p">,</span>
+            <span class="n">batch</span><span class="p">,</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">n_pulses</span><span class="p">),</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Distribute global variables out to each node</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span><span class="p">:</span>
+            <span class="n">distribute</span> <span class="o">=</span> <span class="p">(</span>
+                <span class="n">batch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+            <span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
+
+            <span class="n">global_variables_distributed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span>
+                <span class="n">distribute</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
+                <span class="o">*</span> <span class="n">global_variables</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
+                <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+            <span class="p">)</span>
+
+            <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">global_variables_distributed</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="c1"># DynEdge-convolutions</span>
+        <span class="n">skip_connections</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span>
+        <span class="k">for</span> <span class="n">conv_layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span><span class="p">:</span>
+            <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="n">conv_layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+            <span class="n">skip_connections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="c1"># Skip-cat</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">skip_connections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="c1"># Post-processing</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="c1"># (Optional) Global pooling</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">:</span>
+            <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">)</span>
+            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span><span class="p">:</span>
+                <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+                    <span class="p">[</span>
+                        <span class="n">x</span><span class="p">,</span>
+                        <span class="n">global_variables</span><span class="p">,</span>
+                    <span class="p">],</span>
+                    <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+                <span class="p">)</span>
+
+        <span class="c1"># Read-out</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_readout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">x</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/gnn/dynedge_jinst.html b/_modules/graphnet/models/gnn/dynedge_jinst.html
new file mode 100644
index 000000000..125dc1399
--- /dev/null
+++ b/_modules/graphnet/models/gnn/dynedge_jinst.html
@@ -0,0 +1,521 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.gnn.dynedge_jinst &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/gnn/dynedge_jinst" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.gnn.dynedge_jinst </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-gnn-dynedge-jinst--page-root">Source code for graphnet.models.gnn.dynedge_jinst</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Implementation of the exact DynEdge architecture used in [2209.03042].</span>
+
+<span class="sd">Author: Rasmus Oersoe</span>
+<span class="sd">"""</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+<span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter_max</span><span class="p">,</span> <span class="n">scatter_mean</span><span class="p">,</span> <span class="n">scatter_min</span><span class="p">,</span> <span class="n">scatter_sum</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models.components.layers</span> <span class="kn">import</span> <span class="n">DynEdgeConv</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.utils</span> <span class="kn">import</span> <span class="n">calculate_xyzt_homophily</span>
+
+
+<div class="viewcode-block" id="DynEdgeJINST">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST">[docs]</a>
+<span class="k">class</span> <span class="nc">DynEdgeJINST</span><span class="p">(</span><span class="n">GNN</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""DynEdge (dynamical edge convolutional) model used in [2209.03042]."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">layer_size_scale</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `DynEdgeJINST`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            nb_inputs: Number of input features.</span>
+<span class="sd">            nb_outputs: Number of output features.</span>
+<span class="sd">            layer_size_scale: Integer that scales the size of hidden layers.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Architecture configuration</span>
+        <span class="n">c</span> <span class="o">=</span> <span class="n">layer_size_scale</span>
+        <span class="n">l1</span><span class="p">,</span> <span class="n">l2</span><span class="p">,</span> <span class="n">l3</span><span class="p">,</span> <span class="n">l4</span><span class="p">,</span> <span class="n">l5</span><span class="p">,</span> <span class="n">l6</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="n">nb_inputs</span><span class="p">,</span>
+            <span class="n">c</span> <span class="o">*</span> <span class="mi">16</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span>
+            <span class="n">c</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span>
+            <span class="n">c</span> <span class="o">*</span> <span class="mi">42</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span>
+            <span class="n">c</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span>
+            <span class="n">c</span> <span class="o">*</span> <span class="mi">16</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="n">l6</span><span class="p">)</span>
+
+        <span class="c1"># Graph convolutional operations</span>
+        <span class="n">features_subset</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
+        <span class="n">nb_neighbors</span> <span class="o">=</span> <span class="mi">8</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">conv_add1</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l1</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">l2</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l2</span><span class="p">,</span> <span class="n">l3</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span>
+            <span class="p">),</span>
+            <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span>
+            <span class="n">nb_neighbors</span><span class="o">=</span><span class="n">nb_neighbors</span><span class="p">,</span>
+            <span class="n">features_subset</span><span class="o">=</span><span class="n">features_subset</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">conv_add2</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l3</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">l4</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l4</span><span class="p">,</span> <span class="n">l3</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span>
+            <span class="p">),</span>
+            <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span>
+            <span class="n">nb_neighbors</span><span class="o">=</span><span class="n">nb_neighbors</span><span class="p">,</span>
+            <span class="n">features_subset</span><span class="o">=</span><span class="n">features_subset</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">conv_add3</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l3</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">l4</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l4</span><span class="p">,</span> <span class="n">l3</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span>
+            <span class="p">),</span>
+            <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span>
+            <span class="n">nb_neighbors</span><span class="o">=</span><span class="n">nb_neighbors</span><span class="p">,</span>
+            <span class="n">features_subset</span><span class="o">=</span><span class="n">features_subset</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">conv_add4</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l3</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">l4</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l4</span><span class="p">,</span> <span class="n">l3</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span>
+            <span class="p">),</span>
+            <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span>
+            <span class="n">nb_neighbors</span><span class="o">=</span><span class="n">nb_neighbors</span><span class="p">,</span>
+            <span class="n">features_subset</span><span class="o">=</span><span class="n">features_subset</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Post-processing operations</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nn1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l3</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="n">l1</span><span class="p">,</span> <span class="n">l4</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nn2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l4</span><span class="p">,</span> <span class="n">l5</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nn3</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">l5</span> <span class="o">+</span> <span class="mi">5</span><span class="p">,</span> <span class="n">l6</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">lrelu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">()</span>
+
+<div class="viewcode-block" id="DynEdgeJINST.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Apply learnable forward pass."""</span>
+        <span class="c1"># Convenience variables</span>
+        <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span>
+
+        <span class="c1"># Calculate homophily (scalar variables)</span>
+        <span class="n">h_x</span><span class="p">,</span> <span class="n">h_y</span><span class="p">,</span> <span class="n">h_z</span><span class="p">,</span> <span class="n">h_t</span> <span class="o">=</span> <span class="n">calculate_xyzt_homophily</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+
+        <span class="n">a</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_add1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+        <span class="n">b</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_add2</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+        <span class="n">c</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_add3</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+        <span class="n">d</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_add4</span><span class="p">(</span><span class="n">c</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+
+        <span class="c1"># Skip-cat</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">d</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="c1"># Post-processing</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">nn1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lrelu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">nn2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="c1"># Aggregation across nodes</span>
+        <span class="n">a</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">scatter_max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+        <span class="n">b</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">scatter_min</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+        <span class="n">c</span> <span class="o">=</span> <span class="n">scatter_sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+        <span class="n">d</span> <span class="o">=</span> <span class="n">scatter_mean</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+
+        <span class="c1"># Concatenate aggregations and scalar features</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+            <span class="p">(</span>
+                <span class="n">a</span><span class="p">,</span>
+                <span class="n">b</span><span class="p">,</span>
+                <span class="n">c</span><span class="p">,</span>
+                <span class="n">d</span><span class="p">,</span>
+                <span class="n">h_t</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
+                <span class="n">h_x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
+                <span class="n">h_y</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
+                <span class="n">h_z</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
+                <span class="n">data</span><span class="o">.</span><span class="n">n_pulses</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
+            <span class="p">),</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Read-out</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lrelu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">nn3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lrelu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">x</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html b/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html
new file mode 100644
index 000000000..8895fa50c
--- /dev/null
+++ b/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html
@@ -0,0 +1,618 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.gnn.dynedge_kaggle_tito &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/gnn/dynedge_kaggle_tito" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.gnn.dynedge_kaggle_tito </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-gnn-dynedge-kaggle-tito--page-root">Source code for graphnet.models.gnn.dynedge_kaggle_tito</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Implementation of DynEdge architecture used in.</span>
+
+<span class="sd">                    IceCube - Neutrinos in Deep Ice</span>
+<span class="sd">Reconstruct the direction of neutrinos from the Universe to the South Pole</span>
+
+<span class="sd">Kaggle competition.</span>
+
+<span class="sd">Solution by TITO.</span>
+<span class="sd">"""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span>
+
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.utils</span> <span class="kn">import</span> <span class="n">to_dense_batch</span>
+<span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter_max</span><span class="p">,</span> <span class="n">scatter_mean</span><span class="p">,</span> <span class="n">scatter_min</span><span class="p">,</span> <span class="n">scatter_sum</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models.components.layers</span> <span class="kn">import</span> <span class="n">DynTrans</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.utils</span> <span class="kn">import</span> <span class="n">calculate_xyzt_homophily</span>
+
+<span class="n">GLOBAL_POOLINGS</span> <span class="o">=</span> <span class="p">{</span>
+    <span class="s2">"min"</span><span class="p">:</span> <span class="n">scatter_min</span><span class="p">,</span>
+    <span class="s2">"max"</span><span class="p">:</span> <span class="n">scatter_max</span><span class="p">,</span>
+    <span class="s2">"sum"</span><span class="p">:</span> <span class="n">scatter_sum</span><span class="p">,</span>
+    <span class="s2">"mean"</span><span class="p">:</span> <span class="n">scatter_mean</span><span class="p">,</span>
+<span class="p">}</span>
+
+
+<div class="viewcode-block" id="DynEdgeTITO">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO">[docs]</a>
+<span class="k">class</span> <span class="nc">DynEdgeTITO</span><span class="p">(</span><span class="n">GNN</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""DynEdge (dynamical edge convolutional) model."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">features_subset</span><span class="p">:</span> <span class="nb">slice</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span>
+        <span class="n">dyntrans_layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">global_pooling_schemes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"max"</span><span class="p">],</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `DynEdge`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            nb_inputs: Number of input features on each node.</span>
+<span class="sd">            features_subset: The subset of latent features on each node that</span>
+<span class="sd">                are used as metric dimensions when performing the k-nearest</span>
+<span class="sd">                neighbours clustering. Defaults to [0,1,2,3].</span>
+<span class="sd">            dyntrans_layer_sizes: The layer sizes, or latent feature dimenions,</span>
+<span class="sd">                used in the `DynTrans` layer.</span>
+<span class="sd">            global_pooling_schemes: The list global pooling schemes to use.</span>
+<span class="sd">                Options are: "min", "max", "mean", and "sum".</span>
+<span class="sd">        """</span>
+        <span class="c1"># DynEdge layer sizes</span>
+        <span class="k">if</span> <span class="n">dyntrans_layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">dyntrans_layer_sizes</span> <span class="o">=</span> <span class="p">[</span>
+                <span class="p">(</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                <span class="p">),</span>
+                <span class="p">(</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                <span class="p">),</span>
+                <span class="p">(</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                    <span class="mi">256</span><span class="p">,</span>
+                <span class="p">),</span>
+            <span class="p">]</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dyntrans_layer_sizes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dyntrans_layer_sizes</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dyntrans_layer_sizes</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dyntrans_layer_sizes</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span>
+            <span class="nb">all</span><span class="p">(</span><span class="n">size</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">sizes</span><span class="p">)</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dyntrans_layer_sizes</span>
+        <span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_dyntrans_layer_sizes</span> <span class="o">=</span> <span class="n">dyntrans_layer_sizes</span>
+
+        <span class="c1"># Post-processing layer sizes</span>
+        <span class="n">post_processing_layer_sizes</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="mi">336</span><span class="p">,</span>
+            <span class="mi">256</span><span class="p">,</span>
+        <span class="p">]</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing_layer_sizes</span> <span class="o">=</span> <span class="n">post_processing_layer_sizes</span>
+
+        <span class="c1"># Read-out layer sizes</span>
+        <span class="n">readout_layer_sizes</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="mi">256</span><span class="p">,</span>
+            <span class="mi">128</span><span class="p">,</span>
+        <span class="p">]</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span> <span class="o">=</span> <span class="n">readout_layer_sizes</span>
+
+        <span class="c1"># Global pooling scheme(s)</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">global_pooling_schemes</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="n">global_pooling_schemes</span> <span class="o">=</span> <span class="p">[</span><span class="n">global_pooling_schemes</span><span class="p">]</span>
+
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">global_pooling_schemes</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="k">for</span> <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="n">global_pooling_schemes</span><span class="p">:</span>
+                <span class="k">assert</span> <span class="p">(</span>
+                    <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="n">GLOBAL_POOLINGS</span>
+                <span class="p">),</span> <span class="sa">f</span><span class="s2">"Global pooling scheme </span><span class="si">{</span><span class="n">pooling_scheme</span><span class="si">}</span><span class="s2"> not supported."</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="n">global_pooling_schemes</span> <span class="ow">is</span> <span class="kc">None</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span> <span class="o">=</span> <span class="n">global_pooling_schemes</span>
+
+        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">,</span> <span class="p">(</span>
+            <span class="s2">"No global pooling schemes were request, so cannot add global"</span>
+            <span class="s2">" variables after pooling."</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
+
+        <span class="c1"># Remaining member variables()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_activation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> <span class="o">=</span> <span class="n">nb_inputs</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span> <span class="o">=</span> <span class="mi">5</span> <span class="o">+</span> <span class="n">nb_inputs</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_features_subset</span> <span class="o">=</span> <span class="n">features_subset</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_construct_layers</span><span class="p">()</span>
+
+    <span class="k">def</span> <span class="nf">_construct_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct layers (torch.nn.Modules)."""</span>
+        <span class="c1"># Convolutional operations</span>
+        <span class="n">nb_input_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">()</span>
+        <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">nb_input_features</span>
+        <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dyntrans_layer_sizes</span><span class="p">:</span>
+            <span class="n">conv_layer</span> <span class="o">=</span> <span class="n">DynTrans</span><span class="p">(</span>
+                <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">sizes</span><span class="p">),</span>
+                <span class="n">aggr</span><span class="o">=</span><span class="s2">"max"</span><span class="p">,</span>
+                <span class="n">features_subset</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features_subset</span><span class="p">,</span>
+                <span class="n">n_head</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
+            <span class="p">)</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">conv_layer</span><span class="p">)</span>
+            <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
+
+        <span class="n">post_processing_layers</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing_layer_sizes</span>
+        <span class="p">)</span>
+        <span class="k">for</span> <span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
+            <span class="n">post_processing_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span>
+            <span class="n">post_processing_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span>
+        <span class="n">last_posting_layer_output_dim</span> <span class="o">=</span> <span class="n">nb_out</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">post_processing_layers</span><span class="p">)</span>
+
+        <span class="c1"># Read-out operations</span>
+        <span class="n">nb_poolings</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">)</span>
+            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span>
+            <span class="k">else</span> <span class="mi">1</span>
+        <span class="p">)</span>
+        <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">last_posting_layer_output_dim</span> <span class="o">*</span> <span class="n">nb_poolings</span>
+        <span class="n">nb_latent_features</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span>
+
+        <span class="n">readout_layers</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span><span class="p">)</span>
+        <span class="k">for</span> <span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
+            <span class="n">readout_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span>
+            <span class="n">readout_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_readout</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">readout_layers</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_global_pooling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Perform global pooling."""</span>
+        <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span>
+        <span class="n">pooled</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="k">for</span> <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">:</span>
+            <span class="n">pooling_fn</span> <span class="o">=</span> <span class="n">GLOBAL_POOLINGS</span><span class="p">[</span><span class="n">pooling_scheme</span><span class="p">]</span>
+            <span class="n">pooled_x</span> <span class="o">=</span> <span class="n">pooling_fn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
+                <span class="c1"># `scatter_{min,max}`, which return also an argument, vs.</span>
+                <span class="c1"># `scatter_{mean,sum}`</span>
+                <span class="n">pooled_x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">pooled_x</span>
+            <span class="n">pooled</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">pooled</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_calculate_global_variables</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">edge_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+        <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span>
+        <span class="o">*</span><span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Calculate global variables."""</span>
+        <span class="c1"># Calculate homophily (scalar variables)</span>
+        <span class="n">h_x</span><span class="p">,</span> <span class="n">h_y</span><span class="p">,</span> <span class="n">h_z</span><span class="p">,</span> <span class="n">h_t</span> <span class="o">=</span> <span class="n">calculate_xyzt_homophily</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+
+        <span class="c1"># Calculate mean features</span>
+        <span class="n">global_means</span> <span class="o">=</span> <span class="n">scatter_mean</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+
+        <span class="c1"># Add global variables</span>
+        <span class="n">global_variables</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">global_means</span><span class="p">,</span>
+                <span class="n">h_x</span><span class="p">,</span>
+                <span class="n">h_y</span><span class="p">,</span>
+                <span class="n">h_z</span><span class="p">,</span>
+                <span class="n">h_t</span><span class="p">,</span>
+            <span class="p">]</span>
+            <span class="o">+</span> <span class="p">[</span><span class="n">attr</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">additional_attributes</span><span class="p">],</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">global_variables</span>
+
+<div class="viewcode-block" id="DynEdgeTITO.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Apply learnable forward pass."""</span>
+        <span class="c1"># Convenience variables</span>
+        <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span>
+
+        <span class="n">global_variables</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_calculate_global_variables</span><span class="p">(</span>
+            <span class="n">x</span><span class="p">,</span>
+            <span class="n">edge_index</span><span class="p">,</span>
+            <span class="n">batch</span><span class="p">,</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">n_pulses</span><span class="p">),</span>
+        <span class="p">)</span>
+
+        <span class="c1"># DynEdge-convolutions</span>
+        <span class="k">for</span> <span class="n">conv_layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span><span class="p">:</span>
+            <span class="n">x</span> <span class="o">=</span> <span class="n">conv_layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+
+        <span class="n">x</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">to_dense_batch</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span>
+
+        <span class="c1"># Post-processing</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="c1"># (Optional) Global pooling</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">x</span><span class="p">,</span>
+                <span class="n">global_variables</span><span class="p">,</span>
+            <span class="p">],</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Read-out</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_readout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">x</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/gnn/gnn.html b/_modules/graphnet/models/gnn/gnn.html
new file mode 100644
index 000000000..7a3a22f46
--- /dev/null
+++ b/_modules/graphnet/models/gnn/gnn.html
@@ -0,0 +1,403 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.gnn.gnn &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/gnn/gnn" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.gnn.gnn </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-gnn-gnn--page-root">Source code for graphnet.models.gnn.gnn</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Base GNN-specific `Model` class(es)."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span>
+
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+
+
+<div class="viewcode-block" id="GNN">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN">[docs]</a>
+<span class="k">class</span> <span class="nc">GNN</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base class for all core GNN models in graphnet."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">nb_outputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `GNN`."""</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
+
+        <span class="c1"># Member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> <span class="o">=</span> <span class="n">nb_inputs</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_outputs</span> <span class="o">=</span> <span class="n">nb_outputs</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">nb_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return number of input features."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">nb_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return number of output features."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_outputs</span>
+
+<div class="viewcode-block" id="GNN.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN.forward">[docs]</a>
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Apply learnable forward pass in model."""</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/graphs/edges/edges.html b/_modules/graphnet/models/graphs/edges/edges.html
new file mode 100644
index 000000000..e681629cf
--- /dev/null
+++ b/_modules/graphnet/models/graphs/edges/edges.html
@@ -0,0 +1,563 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.graphs.edges.edges &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/graphs/edges/edges" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.graphs.edges.edges </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../../"versions.json"",
+        target_loc = "../../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-graphs-edges-edges--page-root">Source code for graphnet.models.graphs.edges.edges</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Class(es) for building/connecting graphs."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span><span class="p">,</span> <span class="n">ABC</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn</span> <span class="kn">import</span> <span class="n">knn_graph</span><span class="p">,</span> <span class="n">radius_graph</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.utils</span> <span class="kn">import</span> <span class="n">calculate_distance_matrix</span>
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+
+
+<div class="viewcode-block" id="EdgeDefinition">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition">[docs]</a>
+<span class="k">class</span> <span class="nc">EdgeDefinition</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>  <span class="c1"># pylint: disable=too-few-public-methods</span>
+<span class="w">    </span><span class="sd">"""Base class for graph building."""</span>
+
+<div class="viewcode-block" id="EdgeDefinition.forward">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct edges based on problem specific implementation of.</span>
+
+<span class="sd">        ´_construct_edges´</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            graph: a graph without edges</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            graph: a graph with edges</span>
+<span class="sd">        """</span>
+        <span class="k">if</span> <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warnonce</span><span class="p">(</span>
+                <span class="s2">"GraphBuilder received graph with pre-existing "</span>
+                <span class="s2">"structure. Will overwrite."</span>
+            <span class="p">)</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_edges</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span></div>
+
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_construct_edges</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            graph: graph without edges</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            graph: graph with edges assigned.</span>
+<span class="sd">        """</span></div>
+
+
+
+<div class="viewcode-block" id="KNNEdges">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.KNNEdges">[docs]</a>
+<span class="k">class</span> <span class="nc">KNNEdges</span><span class="p">(</span><span class="n">EdgeDefinition</span><span class="p">):</span>  <span class="c1"># pylint: disable=too-few-public-methods</span>
+<span class="w">    </span><span class="sd">"""Builds edges from the k-nearest neighbours."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">nb_nearest_neighbours</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""K-NN Edge definition.</span>
+
+<span class="sd">        Will connect nodes together with their ´nb_nearest_neighbours´</span>
+<span class="sd">        nearest neighbours in the feature space given by ´columns´.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            nb_nearest_neighbours: number of neighbours.</span>
+<span class="sd">            columns: Node features to use for distance calculation.</span>
+<span class="sd">            Defaults to [0,1,2].</span>
+<span class="sd">        """</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+        <span class="c1"># Member variable(s)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_nearest_neighbours</span> <span class="o">=</span> <span class="n">nb_nearest_neighbours</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span> <span class="o">=</span> <span class="n">columns</span>
+
+    <span class="k">def</span> <span class="nf">_construct_edges</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Define K-NN edges."""</span>
+        <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">knn_graph</span><span class="p">(</span>
+            <span class="n">graph</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span><span class="p">],</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_nb_nearest_neighbours</span><span class="p">,</span>
+            <span class="n">graph</span><span class="o">.</span><span class="n">batch</span><span class="p">,</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">graph</span></div>
+
+
+
+<div class="viewcode-block" id="RadialEdges">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.RadialEdges">[docs]</a>
+<span class="k">class</span> <span class="nc">RadialEdges</span><span class="p">(</span><span class="n">EdgeDefinition</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Builds graph from a sphere of chosen radius centred at each node."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">radius</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
+        <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Radial edges.</span>
+
+<span class="sd">        Connects each node to other nodes that are within a sphere of</span>
+<span class="sd">        radius ´r´ centered at the node. The feature space of ´r´ is defined</span>
+<span class="sd">        by ´columns´</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            radius: radius of sphere</span>
+<span class="sd">            columns: columns of the node feature matrix used.</span>
+<span class="sd">            Defaults to [0,1,2].</span>
+<span class="sd">        """</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+        <span class="c1"># Member variable(s)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_radius</span> <span class="o">=</span> <span class="n">radius</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span> <span class="o">=</span> <span class="n">columns</span>
+
+    <span class="k">def</span> <span class="nf">_construct_edges</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Define radial edges."""</span>
+        <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">radius_graph</span><span class="p">(</span>
+            <span class="n">graph</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span><span class="p">],</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_radius</span><span class="p">,</span>
+            <span class="n">graph</span><span class="o">.</span><span class="n">batch</span><span class="p">,</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">graph</span></div>
+
+
+
+<div class="viewcode-block" id="EuclideanEdges">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EuclideanEdges">[docs]</a>
+<span class="k">class</span> <span class="nc">EuclideanEdges</span><span class="p">(</span><span class="n">EdgeDefinition</span><span class="p">):</span>  <span class="c1"># pylint: disable=too-few-public-methods</span>
+<span class="w">    </span><span class="sd">"""Builds edges according to Euclidean distance between nodes.</span>
+
+<span class="sd">    See https://arxiv.org/pdf/1809.06166.pdf.</span>
+<span class="sd">    """</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">sigma</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
+        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
+        <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `EuclideanEdges`."""</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="n">columns</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">columns</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span>
+
+        <span class="c1"># Member variable(s)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_sigma</span> <span class="o">=</span> <span class="n">sigma</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span> <span class="o">=</span> <span class="n">threshold</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span> <span class="o">=</span> <span class="n">columns</span>
+
+    <span class="k">def</span> <span class="nf">_construct_edges</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Forward pass."""</span>
+        <span class="c1"># Constructs the adjacency matrix from the raw, DOM-level data and</span>
+        <span class="c1"># returns this matrix</span>
+        <span class="k">if</span> <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
+                <span class="s2">"WARNING: GraphBuilder received graph with pre-existing "</span>
+                <span class="s2">"structure. Will overwrite."</span>
+            <span class="p">)</span>
+
+        <span class="n">xyz_coords</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span><span class="p">]</span>
+
+        <span class="c1"># Construct block-diagonal matrix indicating whether pulses belong to</span>
+        <span class="c1"># the same event in the batch</span>
+        <span class="n">batch_mask</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">batch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="n">graph</span><span class="o">.</span><span class="n">batch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span>
+        <span class="p">)</span>
+
+        <span class="n">distance_matrix</span> <span class="o">=</span> <span class="n">calculate_distance_matrix</span><span class="p">(</span><span class="n">xyz_coords</span><span class="p">)</span>
+        <span class="n">affinity_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span>
+            <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">distance_matrix</span><span class="o">**</span><span class="mi">2</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sigma</span><span class="o">**</span><span class="mi">2</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Use softmax to normalise all adjacencies to one for each node</span>
+        <span class="n">exp_row_sums</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">affinity_matrix</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">weighted_adj_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span>
+            <span class="n">affinity_matrix</span>
+        <span class="p">)</span> <span class="o">/</span> <span class="n">exp_row_sums</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="c1"># Only include edges with weights that exceed the chosen threshold (and</span>
+        <span class="c1"># are part of the same event)</span>
+        <span class="n">sources</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span>
+            <span class="p">(</span><span class="n">weighted_adj_matrix</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">batch_mask</span><span class="p">)</span>
+        <span class="p">)</span>
+        <span class="n">edge_weights</span> <span class="o">=</span> <span class="n">weighted_adj_matrix</span><span class="p">[</span><span class="n">sources</span><span class="p">,</span> <span class="n">targets</span><span class="p">]</span>
+
+        <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">sources</span><span class="p">,</span> <span class="n">targets</span><span class="p">))</span>
+        <span class="n">graph</span><span class="o">.</span><span class="n">edge_weight</span> <span class="o">=</span> <span class="n">edge_weights</span>
+
+        <span class="k">return</span> <span class="n">graph</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/graphs/graph_definition.html b/_modules/graphnet/models/graphs/graph_definition.html
new file mode 100644
index 000000000..75075b762
--- /dev/null
+++ b/_modules/graphnet/models/graphs/graph_definition.html
@@ -0,0 +1,634 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.graphs.graph_definition &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/graphs/graph_definition" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.graphs.graph_definition </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-graphs-graph-definition--page-root">Source code for graphnet.models.graphs.graph_definition</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Modules for defining graphs.</span>
+
+<span class="sd">These are self-contained graph definitions that hold all the graph-altering</span>
+<span class="sd">code in graphnet. These modules define what the GNNs sees as input and can be</span>
+<span class="sd">passed to dataloaders during training and deployment.</span>
+<span class="sd">"""</span>
+
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Callable</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models.detector</span> <span class="kn">import</span> <span class="n">Detector</span>
+<span class="kn">from</span> <span class="nn">.edges</span> <span class="kn">import</span> <span class="n">EdgeDefinition</span>
+<span class="kn">from</span> <span class="nn">.nodes</span> <span class="kn">import</span> <span class="n">NodeDefinition</span>
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+
+
+<div class="viewcode-block" id="GraphDefinition">
+<a class="viewcode-back" href="../../../../api/graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition">[docs]</a>
+<span class="k">class</span> <span class="nc">GraphDefinition</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""An Abstract class to create graph definitions from."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">detector</span><span class="p">:</span> <span class="n">Detector</span><span class="p">,</span>
+        <span class="n">node_definition</span><span class="p">:</span> <span class="n">NodeDefinition</span><span class="p">,</span>
+        <span class="n">edge_definition</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">EdgeDefinition</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct ´GraphDefinition´. The ´detector´ holds.</span>
+
+<span class="sd">        ´Detector´-specific code. E.g. scaling/standardization and geometry</span>
+<span class="sd">        tables.</span>
+
+<span class="sd">        ´node_definition´ defines the nodes in the graph.</span>
+
+<span class="sd">        ´edge_definition´ defines the connectivity of the nodes in the graph.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            detector: The corresponding ´Detector´ representing the data.</span>
+<span class="sd">            node_definition: Definition of nodes.</span>
+<span class="sd">            edge_definition: Definition of edges. Defaults to None.</span>
+<span class="sd">            node_feature_names: Names of node feature columns. Defaults to None</span>
+<span class="sd">            dtype: data type used for node features. e.g. ´torch.float´</span>
+<span class="sd">        """</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+        <span class="c1"># Member Variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_detector</span> <span class="o">=</span> <span class="n">detector</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_edge_definition</span> <span class="o">=</span> <span class="n">edge_definition</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_node_definition</span> <span class="o">=</span> <span class="n">node_definition</span>
+        <span class="k">if</span> <span class="n">node_feature_names</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="c1"># Assume all features in Detector is used.</span>
+            <span class="n">node_feature_names</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_detector</span><span class="o">.</span><span class="n">feature_map</span><span class="p">()</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>  <span class="c1"># type: ignore</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span> <span class="o">=</span> <span class="n">node_feature_names</span>
+
+        <span class="c1"># Set data type</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
+
+        <span class="c1"># Set Input / Output dimensions</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_node_definition</span><span class="o">.</span><span class="n">set_number_of_inputs</span><span class="p">(</span>
+            <span class="n">node_feature_names</span><span class="o">=</span><span class="n">node_feature_names</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nb_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_definition</span><span class="o">.</span><span class="n">nb_outputs</span>
+
+<div class="viewcode-block" id="GraphDefinition.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>  <span class="c1"># type: ignore</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">node_features</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
+        <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="n">truth_dicts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">custom_label_functions</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">data_path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct graph as ´Data´ object.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            node_features: node features for graph. Shape ´[num_nodes, d]´</span>
+<span class="sd">            node_feature_names: name of each column. Shape ´[,d]´.</span>
+<span class="sd">            truth_dicts: Dictionary containing truth labels.</span>
+<span class="sd">            custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.</span>
+<span class="sd">            loss_weight_column: Name of column that holds loss weight. Defaults to None.</span>
+<span class="sd">            loss_weight: Loss weight associated with event. Defaults to None.</span>
+<span class="sd">            loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.</span>
+<span class="sd">            data_path: Path to dataset data files. Defaults to None.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            graph</span>
+<span class="sd">        """</span>
+        <span class="c1"># Checks</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_validate_input</span><span class="p">(</span>
+            <span class="n">node_features</span><span class="o">=</span><span class="n">node_features</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="o">=</span><span class="n">node_feature_names</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Transform to pytorch tensor</span>
+        <span class="n">node_features</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">node_features</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
+
+        <span class="c1"># Standardize / Scale  node features</span>
+        <span class="n">node_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_detector</span><span class="p">(</span><span class="n">node_features</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">)</span>
+
+        <span class="c1"># Create graph</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_definition</span><span class="p">(</span><span class="n">node_features</span><span class="p">)</span>
+
+        <span class="c1"># Attach number of pulses as static attribute.</span>
+        <span class="n">graph</span><span class="o">.</span><span class="n">n_pulses</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">node_features</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
+
+        <span class="c1"># Assign edges</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_edge_definition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_edge_definition</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warnonce</span><span class="p">(</span>
+                <span class="s2">"No EdgeDefinition provided. Graphs will not have edges defined!"</span>
+            <span class="p">)</span>
+
+        <span class="c1"># Attach data path - useful for Ensemble datasets.</span>
+        <span class="k">if</span> <span class="n">data_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">graph</span><span class="p">[</span><span class="s2">"dataset_path"</span><span class="p">]</span> <span class="o">=</span> <span class="n">data_path</span>
+
+        <span class="c1"># Attach loss weights if they exist</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_loss_weights</span><span class="p">(</span>
+            <span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span>
+            <span class="n">loss_weight</span><span class="o">=</span><span class="n">loss_weight</span><span class="p">,</span>
+            <span class="n">loss_weight_column</span><span class="o">=</span><span class="n">loss_weight_column</span><span class="p">,</span>
+            <span class="n">loss_weight_default_value</span><span class="o">=</span><span class="n">loss_weight_default_value</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Attach default truth labels and node truths</span>
+        <span class="k">if</span> <span class="n">truth_dicts</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_truth</span><span class="p">(</span><span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span> <span class="n">truth_dicts</span><span class="o">=</span><span class="n">truth_dicts</span><span class="p">)</span>
+
+        <span class="c1"># Attach custom truth labels</span>
+        <span class="k">if</span> <span class="n">custom_label_functions</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_custom_labels</span><span class="p">(</span>
+                <span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span> <span class="n">custom_label_functions</span><span class="o">=</span><span class="n">custom_label_functions</span>
+            <span class="p">)</span>
+
+        <span class="c1"># Attach node features as seperate fields. MAY NOT CONTAIN 'x'</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_features_individually</span><span class="p">(</span>
+            <span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="o">=</span><span class="n">node_feature_names</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Add GraphDefinition Stamp</span>
+        <span class="n">graph</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
+        <span class="k">return</span> <span class="n">graph</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_validate_input</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">node_features</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="c1"># node feature matrix dimension check</span>
+        <span class="k">assert</span> <span class="n">node_features</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">)</span>
+
+        <span class="c1"># check that provided features for input is the same that the ´Graph´</span>
+        <span class="c1"># was instantiated with.</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span>
+        <span class="p">),</span> <span class="sa">f</span><span class="s2">"""Input features (</span><span class="si">{</span><span class="n">node_feature_names</span><span class="si">}</span><span class="s2">) is not what </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> was instatiated with (</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span><span class="si">}</span><span class="s2">)"""</span>
+        <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">)):</span>
+            <span class="k">assert</span> <span class="p">(</span>
+                <span class="n">node_feature_names</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
+            <span class="p">),</span> <span class="sa">f</span><span class="s2">""" Order of node features in data are not the same as expected. Got </span><span class="si">{</span><span class="n">node_feature_names</span><span class="si">}</span><span class="s2"> vs. </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span><span class="si">}</span><span class="s2">"""</span>
+
+    <span class="k">def</span> <span class="nf">_add_loss_weights</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span>
+        <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Attempt to store a loss weight in the graph for use during training.</span>
+
+<span class="sd">        I.e. `graph[loss_weight_column] = loss_weight`</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            loss_weight: The non-negative weight to be stored.</span>
+<span class="sd">            graph: Data object representing the event.</span>
+<span class="sd">            loss_weight_column: The name under which the weight is stored in</span>
+<span class="sd">                                 the graph.</span>
+<span class="sd">            loss_weight_default_value: The default value used if</span>
+<span class="sd">                                        none was retrieved.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            A graph with loss weight added, if available.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Add loss weight to graph.</span>
+        <span class="k">if</span> <span class="n">loss_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">loss_weight_column</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="c1"># No loss weight was retrieved, i.e., it is missing for the current</span>
+            <span class="c1"># event.</span>
+            <span class="k">if</span> <span class="n">loss_weight</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
+                <span class="k">if</span> <span class="n">loss_weight_default_value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                    <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                        <span class="s2">"At least one event is missing an entry in "</span>
+                        <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">loss_weight_column</span><span class="si">}</span><span class="s2"> "</span>
+                        <span class="s2">"but loss_weight_default_value is None."</span>
+                    <span class="p">)</span>
+                <span class="n">graph</span><span class="p">[</span><span class="n">loss_weight_column</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_default_value</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
+                <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="n">graph</span><span class="p">[</span><span class="n">loss_weight_column</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
+                    <span class="n">loss_weight</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
+                <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">graph</span>
+
+    <span class="k">def</span> <span class="nf">_add_truth</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">truth_dicts</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Add truth labels from ´truth_dicts´ to ´graph´.</span>
+
+<span class="sd">        I.e. ´graph[key] = truth_dict[key]´</span>
+
+
+<span class="sd">        Args:</span>
+<span class="sd">            graph: graph where the label will be stored</span>
+<span class="sd">            truth_dicts: dictionary containing the labels</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            graph with labels</span>
+<span class="sd">        """</span>
+        <span class="c1"># Write attributes, either target labels, truth info or original</span>
+        <span class="c1"># features.</span>
+        <span class="k">for</span> <span class="n">truth_dict</span> <span class="ow">in</span> <span class="n">truth_dicts</span><span class="p">:</span>
+            <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">truth_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
+                <span class="k">try</span><span class="p">:</span>
+                    <span class="n">graph</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
+                <span class="k">except</span> <span class="ne">TypeError</span><span class="p">:</span>
+                    <span class="c1"># Cannot convert `value` to Tensor due to its data type,</span>
+                    <span class="c1"># e.g. `str`.</span>
+                    <span class="bp">self</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
+                        <span class="p">(</span>
+                            <span class="sa">f</span><span class="s2">"Could not assign `</span><span class="si">{</span><span class="n">key</span><span class="si">}</span><span class="s2">` with type "</span>
+                            <span class="sa">f</span><span class="s2">"'</span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">value</span><span class="p">)</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">' as attribute to graph."</span>
+                        <span class="p">)</span>
+                    <span class="p">)</span>
+        <span class="k">return</span> <span class="n">graph</span>
+
+    <span class="k">def</span> <span class="nf">_add_features_individually</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span>
+        <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+        <span class="c1"># Additionally add original features as (static) attributes</span>
+        <span class="n">graph</span><span class="o">.</span><span class="n">features</span> <span class="o">=</span> <span class="n">node_feature_names</span>
+        <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">feature</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">):</span>
+            <span class="k">if</span> <span class="n">feature</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">"x"</span><span class="p">]:</span>  <span class="c1"># reserved for node features.</span>
+                <span class="n">graph</span><span class="p">[</span><span class="n">feature</span><span class="p">]</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">index</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">warnonce</span><span class="p">(</span>
+<span class="w">                    </span><span class="sd">"""Cannot assign graph['x']. This field is reserved for node features. Please rename your input feature."""</span>
+                <span class="p">)</span>
+        <span class="k">return</span> <span class="n">graph</span>
+
+    <span class="k">def</span> <span class="nf">_add_custom_labels</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span>
+        <span class="n">custom_label_functions</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">Any</span><span class="p">]],</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+        <span class="c1"># Add custom labels to the graph</span>
+        <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">fn</span> <span class="ow">in</span> <span class="n">custom_label_functions</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
+            <span class="n">graph</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">fn</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">graph</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/graphs/graphs.html b/_modules/graphnet/models/graphs/graphs.html
new file mode 100644
index 000000000..5f3dc4be7
--- /dev/null
+++ b/_modules/graphnet/models/graphs/graphs.html
@@ -0,0 +1,410 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.graphs.graphs &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/graphs/graphs" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.graphs.graphs </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-graphs-graphs--page-root">Source code for graphnet.models.graphs.graphs</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""A module containing different graph representations in GraphNeT."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">.graph_definition</span> <span class="kn">import</span> <span class="n">GraphDefinition</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.detector</span> <span class="kn">import</span> <span class="n">Detector</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.graphs.edges</span> <span class="kn">import</span> <span class="n">EdgeDefinition</span><span class="p">,</span> <span class="n">KNNEdges</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.graphs.nodes</span> <span class="kn">import</span> <span class="n">NodeDefinition</span>
+
+
+<div class="viewcode-block" id="KNNGraph">
+<a class="viewcode-back" href="../../../../api/graphnet.models.graphs.graphs.html#graphnet.models.graphs.graphs.KNNGraph">[docs]</a>
+<span class="k">class</span> <span class="nc">KNNGraph</span><span class="p">(</span><span class="n">GraphDefinition</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""A Graph representation where Edges are drawn to nearest neighbours."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">detector</span><span class="p">:</span> <span class="n">Detector</span><span class="p">,</span>
+        <span class="n">node_definition</span><span class="p">:</span> <span class="n">NodeDefinition</span><span class="p">,</span>
+        <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
+        <span class="n">nb_nearest_neighbours</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
+        <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct k-nn graph representation.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            detector: Detector that represents your data.</span>
+<span class="sd">            node_definition: Definition of nodes in the graph.</span>
+<span class="sd">            node_feature_names: Name of node features.</span>
+<span class="sd">            dtype: data type for node features.</span>
+<span class="sd">            nb_nearest_neighbours: Number of edges for each node. Defaults to 8.</span>
+<span class="sd">            columns: node feature columns used for distance calculation</span>
+<span class="sd">            . Defaults to [0, 1, 2].</span>
+<span class="sd">        """</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">detector</span><span class="o">=</span><span class="n">detector</span><span class="p">,</span>
+            <span class="n">node_definition</span><span class="o">=</span><span class="n">node_definition</span><span class="p">,</span>
+            <span class="n">edge_definition</span><span class="o">=</span><span class="n">KNNEdges</span><span class="p">(</span>
+                <span class="n">nb_nearest_neighbours</span><span class="o">=</span><span class="n">nb_nearest_neighbours</span><span class="p">,</span>
+                <span class="n">columns</span><span class="o">=</span><span class="n">columns</span><span class="p">,</span>
+            <span class="p">),</span>
+            <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
+            <span class="n">node_feature_names</span><span class="o">=</span><span class="n">node_feature_names</span><span class="p">,</span>
+        <span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/graphs/nodes/nodes.html b/_modules/graphnet/models/graphs/nodes/nodes.html
new file mode 100644
index 000000000..3e7040189
--- /dev/null
+++ b/_modules/graphnet/models/graphs/nodes/nodes.html
@@ -0,0 +1,444 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.graphs.nodes.nodes &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/graphs/nodes/nodes" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.graphs.nodes.nodes </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../../"versions.json"",
+        target_loc = "../../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-graphs-nodes-nodes--page-root">Source code for graphnet.models.graphs.nodes.nodes</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Class(es) for building/connecting graphs."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+
+
+<div class="viewcode-block" id="NodeDefinition">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition">[docs]</a>
+<span class="k">class</span> <span class="nc">NodeDefinition</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>  <span class="c1"># pylint: disable=too-few-public-methods</span>
+<span class="w">    </span><span class="sd">"""Base class for graph building."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `Detector`."""</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+<div class="viewcode-block" id="NodeDefinition.forward">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct nodes from raw node features.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            x: standardized node features with shape ´[num_pulses, d]´,</span>
+<span class="sd">            where ´d´ is the number of node features.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            graph: a graph without edges</span>
+<span class="sd">        """</span>
+        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_nodes</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">graph</span></div>
+
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">nb_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return number of output features.</span>
+
+<span class="sd">        This the default, but may be overridden by specific inheriting classes.</span>
+<span class="sd">        """</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span>
+
+<div class="viewcode-block" id="NodeDefinition.set_number_of_inputs">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">set_number_of_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return number of inputs expected by node definition.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            node_feature_names: name of each node feature column.</span>
+<span class="sd">        """</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">)</span></div>
+
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_construct_nodes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct nodes from raw node features ´x´.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            x: standardized node features with shape ´[num_pulses, d]´,</span>
+<span class="sd">            where ´d´ is the number of node features.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            graph: graph without edges.</span>
+<span class="sd">        """</span></div>
+
+
+
+<div class="viewcode-block" id="NodesAsPulses">
+<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodesAsPulses">[docs]</a>
+<span class="k">class</span> <span class="nc">NodesAsPulses</span><span class="p">(</span><span class="n">NodeDefinition</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Represent each measured pulse of Cherenkov Radiation as a node."""</span>
+
+    <span class="k">def</span> <span class="nf">_construct_nodes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Data</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">Data</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/model.html b/_modules/graphnet/models/model.html
new file mode 100644
index 000000000..9654c1e81
--- /dev/null
+++ b/_modules/graphnet/models/model.html
@@ -0,0 +1,726 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.model &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/model" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.model </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-model--page-root">Source code for graphnet.models.model</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Base class(es) for building models."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
+<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
+<span class="kn">import</span> <span class="nn">dill</span>
+<span class="kn">import</span> <span class="nn">os.path</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">LightningModule</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning.callbacks.callback</span> <span class="kn">import</span> <span class="n">Callback</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning.callbacks</span> <span class="kn">import</span> <span class="n">EarlyStopping</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning.loggers.logger</span> <span class="kn">import</span> <span class="n">Logger</span> <span class="k">as</span> <span class="n">LightningLogger</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">SequentialSampler</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">Configurable</span><span class="p">,</span> <span class="n">ModelConfig</span>
+<span class="kn">from</span> <span class="nn">graphnet.training.callbacks</span> <span class="kn">import</span> <span class="n">ProgressBar</span>
+
+
+<div class="viewcode-block" id="Model">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model">[docs]</a>
+<span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Logger</span><span class="p">,</span> <span class="n">Configurable</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">ABC</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base class for all models in graphnet."""</span>
+
+<div class="viewcode-block" id="Model.forward">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.forward">[docs]</a>
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Forward pass."""</span></div>
+
+
+    <span class="nd">@staticmethod</span>
+    <span class="k">def</span> <span class="nf">_construct_trainer</span><span class="p">(</span>
+        <span class="n">max_epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
+        <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">callbacks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Callback</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">ckpt_path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">logger</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LightningLogger</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">log_every_n_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
+        <span class="n">gradient_clip_val</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"ddp"</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">trainer_kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Trainer</span><span class="p">:</span>
+
+        <span class="k">if</span> <span class="n">gpus</span><span class="p">:</span>
+            <span class="n">accelerator</span> <span class="o">=</span> <span class="s2">"gpu"</span>
+            <span class="n">devices</span> <span class="o">=</span> <span class="n">gpus</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">accelerator</span> <span class="o">=</span> <span class="s2">"cpu"</span>
+            <span class="n">devices</span> <span class="o">=</span> <span class="mi">1</span>
+
+        <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span>
+            <span class="n">accelerator</span><span class="o">=</span><span class="n">accelerator</span><span class="p">,</span>
+            <span class="n">devices</span><span class="o">=</span><span class="n">devices</span><span class="p">,</span>
+            <span class="n">max_epochs</span><span class="o">=</span><span class="n">max_epochs</span><span class="p">,</span>
+            <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span><span class="p">,</span>
+            <span class="n">log_every_n_steps</span><span class="o">=</span><span class="n">log_every_n_steps</span><span class="p">,</span>
+            <span class="n">logger</span><span class="o">=</span><span class="n">logger</span><span class="p">,</span>
+            <span class="n">gradient_clip_val</span><span class="o">=</span><span class="n">gradient_clip_val</span><span class="p">,</span>
+            <span class="n">strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span>
+            <span class="n">default_root_dir</span><span class="o">=</span><span class="n">ckpt_path</span><span class="p">,</span>
+            <span class="o">**</span><span class="n">trainer_kwargs</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">trainer</span>
+
+<div class="viewcode-block" id="Model.fit">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.fit">[docs]</a>
+    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">train_dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span>
+        <span class="n">val_dataloader</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">DataLoader</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">max_epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
+        <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">callbacks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Callback</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">ckpt_path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">logger</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LightningLogger</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">log_every_n_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
+        <span class="n">gradient_clip_val</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"ddp"</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">trainer_kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Fit `Model` using `pytorch_lightning.Trainer`."""</span>
+        <span class="c1"># Checks</span>
+        <span class="k">if</span> <span class="n">callbacks</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">callbacks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_default_callbacks</span><span class="p">(</span>
+                <span class="n">val_dataloader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span>
+            <span class="p">)</span>
+        <span class="k">elif</span> <span class="n">val_dataloader</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">callbacks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_early_stopping</span><span class="p">(</span>
+                <span class="n">val_dataloader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span>
+            <span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="n">trainer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_trainer</span><span class="p">(</span>
+            <span class="n">max_epochs</span><span class="o">=</span><span class="n">max_epochs</span><span class="p">,</span>
+            <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span>
+            <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span><span class="p">,</span>
+            <span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_path</span><span class="p">,</span>
+            <span class="n">logger</span><span class="o">=</span><span class="n">logger</span><span class="p">,</span>
+            <span class="n">log_every_n_steps</span><span class="o">=</span><span class="n">log_every_n_steps</span><span class="p">,</span>
+            <span class="n">gradient_clip_val</span><span class="o">=</span><span class="n">gradient_clip_val</span><span class="p">,</span>
+            <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span>
+            <span class="o">**</span><span class="n">trainer_kwargs</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="k">try</span><span class="p">:</span>
+            <span class="n">trainer</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span>
+                <span class="bp">self</span><span class="p">,</span> <span class="n">train_dataloader</span><span class="p">,</span> <span class="n">val_dataloader</span><span class="p">,</span> <span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_path</span>
+            <span class="p">)</span>
+        <span class="k">except</span> <span class="ne">KeyboardInterrupt</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"[ctrl+c] Exiting gracefully."</span><span class="p">)</span>
+            <span class="k">pass</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_create_default_callbacks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val_dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">:</span>
+        <span class="n">callbacks</span> <span class="o">=</span> <span class="p">[</span><span class="n">ProgressBar</span><span class="p">()]</span>
+        <span class="n">callbacks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_early_stopping</span><span class="p">(</span>
+            <span class="n">val_dataloader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="n">callbacks</span>
+
+    <span class="k">def</span> <span class="nf">_add_early_stopping</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">val_dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">callbacks</span><span class="p">:</span> <span class="n">List</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">:</span>
+        <span class="k">if</span> <span class="n">val_dataloader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">callbacks</span>
+        <span class="n">has_early_stopping</span> <span class="o">=</span> <span class="kc">False</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">callbacks</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+        <span class="k">for</span> <span class="n">callback</span> <span class="ow">in</span> <span class="n">callbacks</span><span class="p">:</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">callback</span><span class="p">,</span> <span class="n">EarlyStopping</span><span class="p">):</span>
+                <span class="n">has_early_stopping</span> <span class="o">=</span> <span class="kc">True</span>
+
+        <span class="k">if</span> <span class="ow">not</span> <span class="n">has_early_stopping</span><span class="p">:</span>
+            <span class="n">callbacks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
+                <span class="n">EarlyStopping</span><span class="p">(</span>
+                    <span class="n">monitor</span><span class="o">=</span><span class="s2">"val_loss"</span><span class="p">,</span>
+                    <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
+                <span class="p">)</span>
+            <span class="p">)</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning_once</span><span class="p">(</span>
+                <span class="s2">"Got validation dataloader but no EarlyStopping callback. An "</span>
+                <span class="s2">"EarlyStopping callback has been added automatically with "</span>
+                <span class="s2">"patience=5 and monitor = 'val_loss'."</span>
+            <span class="p">)</span>
+        <span class="k">return</span> <span class="n">callbacks</span>
+
+<div class="viewcode-block" id="Model.predict">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.predict">[docs]</a>
+    <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span>
+        <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return predictions for `dataloader`.</span>
+
+<span class="sd">        Returns a list of Tensors, one for each model output.</span>
+<span class="sd">        """</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
+
+        <span class="n">callbacks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_default_callbacks</span><span class="p">(</span>
+            <span class="n">val_dataloader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="n">inference_trainer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_trainer</span><span class="p">(</span>
+            <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span>
+            <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span>
+            <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="n">predictions_list</span> <span class="o">=</span> <span class="n">inference_trainer</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">),</span> <span class="s2">"Got no predictions"</span>
+
+        <span class="n">nb_outputs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
+        <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">preds</span><span class="p">[</span><span class="n">ix</span><span class="p">]</span> <span class="k">for</span> <span class="n">preds</span> <span class="ow">in</span> <span class="n">predictions_list</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+            <span class="k">for</span> <span class="n">ix</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nb_outputs</span><span class="p">)</span>
+        <span class="p">]</span>
+
+        <span class="k">return</span> <span class="n">predictions</span></div>
+
+
+<div class="viewcode-block" id="Model.predict_as_dataframe">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.predict_as_dataframe">[docs]</a>
+    <span class="k">def</span> <span class="nf">predict_as_dataframe</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span>
+        <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return predictions for `dataloader` as a DataFrame.</span>
+
+<span class="sd">        Include `additional_attributes` as additional columns in the output</span>
+<span class="sd">        DataFrame.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="n">additional_attributes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">additional_attributes</span> <span class="o">=</span> <span class="p">[]</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">additional_attributes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+
+        <span class="k">if</span> <span class="p">(</span>
+            <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">sampler</span><span class="p">,</span> <span class="n">SequentialSampler</span><span class="p">)</span>
+            <span class="ow">and</span> <span class="n">additional_attributes</span>
+        <span class="p">):</span>
+            <span class="nb">print</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">sampler</span><span class="p">)</span>
+            <span class="k">raise</span> <span class="ne">UserWarning</span><span class="p">(</span>
+                <span class="s2">"DataLoader has a `sampler` that is not `SequentialSampler`, "</span>
+                <span class="s2">"indicating that shuffling is enabled. Using "</span>
+                <span class="s2">"`predict_as_dataframe` with `additional_attributes` assumes "</span>
+                <span class="s2">"that the sequence of batches in `dataloader` are "</span>
+                <span class="s2">"deterministic. Either call this method a `dataloader` which "</span>
+                <span class="s2">"doesn't resample batches; or do not request "</span>
+                <span class="s2">"`additional_attributes`."</span>
+            <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Column names for predictions are: </span><span class="se">\n</span><span class="s2"> </span><span class="si">{</span><span class="n">prediction_columns</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
+        <span class="n">predictions_torch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
+            <span class="n">dataloader</span><span class="o">=</span><span class="n">dataloader</span><span class="p">,</span>
+            <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span>
+            <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="n">predictions</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">predictions_torch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
+        <span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">)</span> <span class="o">==</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">(</span>
+            <span class="sa">f</span><span class="s2">"Number of provided column names (</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">)</span><span class="si">}</span><span class="s2">) and "</span>
+            <span class="sa">f</span><span class="s2">"number of output columns (</span><span class="si">{</span><span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">) don't match."</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Get additional attributes</span>
+        <span class="n">attributes</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]]</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">(</span>
+            <span class="p">[(</span><span class="n">attr</span><span class="p">,</span> <span class="p">[])</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">additional_attributes</span><span class="p">]</span>
+        <span class="p">)</span>
+
+        <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
+            <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">attributes</span><span class="p">:</span>
+                <span class="n">attribute</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">attr</span><span class="p">]</span>
+                <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">attribute</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
+                    <span class="n">attribute</span> <span class="o">=</span> <span class="n">attribute</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
+
+                <span class="c1"># Check if node level predictions</span>
+                <span class="c1"># If true, additional attributes are repeated</span>
+                <span class="c1"># to make dimensions fit</span>
+                <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">dataset</span><span class="p">):</span>
+                    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">attribute</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span>
+                        <span class="n">batch</span><span class="o">.</span><span class="n">n_pulses</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
+                    <span class="p">):</span>
+                        <span class="n">attribute</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
+                            <span class="n">attribute</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_pulses</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
+                        <span class="p">)</span>
+                        <span class="k">try</span><span class="p">:</span>
+                            <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">attribute</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">x</span><span class="p">)</span>
+                        <span class="k">except</span> <span class="ne">AssertionError</span><span class="p">:</span>
+                            <span class="bp">self</span><span class="o">.</span><span class="n">warning_once</span><span class="p">(</span>
+                                <span class="s2">"Could not automatically adjust length"</span>
+                                <span class="sa">f</span><span class="s2">"of additional attribute </span><span class="si">{</span><span class="n">attr</span><span class="si">}</span><span class="s2"> to match length of"</span>
+                                <span class="sa">f</span><span class="s2">"predictions. Make sure </span><span class="si">{</span><span class="n">attr</span><span class="si">}</span><span class="s2"> is a graph-level or"</span>
+                                <span class="s2">"node-level attribute. Attribute skipped."</span>
+                            <span class="p">)</span>
+                            <span class="k">pass</span>
+                <span class="n">attributes</span><span class="p">[</span><span class="n">attr</span><span class="p">]</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">attribute</span><span class="p">)</span>
+
+        <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
+            <span class="p">[</span><span class="n">predictions</span><span class="p">]</span>
+            <span class="o">+</span> <span class="p">[</span>
+                <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">values</span><span class="p">)[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span>
+                <span class="k">for</span> <span class="n">values</span> <span class="ow">in</span> <span class="n">attributes</span><span class="o">.</span><span class="n">values</span><span class="p">()</span>
+            <span class="p">],</span>
+            <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="n">results</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span>
+            <span class="n">data</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">prediction_columns</span> <span class="o">+</span> <span class="n">additional_attributes</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="n">results</span></div>
+
+
+<div class="viewcode-block" id="Model.save">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.save">[docs]</a>
+    <span class="k">def</span> <span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Save entire model to `path`."""</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="n">path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".pth"</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
+                <span class="s2">"It is recommended to use the .pth suffix for model files."</span>
+            <span class="p">)</span>
+        <span class="n">dirname</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">dirname</span><span class="p">:</span>
+            <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">dirname</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span> <span class="n">path</span><span class="p">,</span> <span class="n">pickle_module</span><span class="o">=</span><span class="n">dill</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Model saved to </span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="Model.load">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.load">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"Model"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Load entire model from `path`."""</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">pickle_module</span><span class="o">=</span><span class="n">dill</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="Model.save_state_dict">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.save_state_dict">[docs]</a>
+    <span class="k">def</span> <span class="nf">save_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Save model `state_dict` to `path`."""</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="n">path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".pth"</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
+                <span class="s2">"It is recommended to use the .pth suffix for state_dict files."</span>
+            <span class="p">)</span>
+        <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">path</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Model state_dict saved to </span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="Model.load_state_dict">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.load_state_dict">[docs]</a>
+    <span class="k">def</span> <span class="nf">load_state_dict</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Dict</span><span class="p">],</span> <span class="o">**</span><span class="n">kargs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Any</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"Model"</span><span class="p">:</span>  <span class="c1"># pylint: disable=arguments-differ</span>
+<span class="w">        </span><span class="sd">"""Load model `state_dict` from `path`."""</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">state_dict</span> <span class="o">=</span> <span class="n">path</span>
+        <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="o">**</span><span class="n">kargs</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="Model.from_config">
+<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.from_config">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span>  <span class="c1"># type: ignore[override]</span>
+        <span class="bp">cls</span><span class="p">,</span>
+        <span class="n">source</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span>
+        <span class="n">trust</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">load_modules</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"Model"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `Model` instance from `source` configuration.</span>
+
+<span class="sd">        Arguments:</span>
+<span class="sd">            trust: Whether to trust the ModelConfig file enough to `eval(...)`</span>
+<span class="sd">                any lambda function expressions contained.</span>
+<span class="sd">            load_modules: List of modules used in the definition of the model</span>
+<span class="sd">                which, as a consequence, need to be loaded into the global</span>
+<span class="sd">                namespace. Defaults to loading `torch`.</span>
+
+<span class="sd">        Raises:</span>
+<span class="sd">            ValueError: If the ModelConfig contains lambda functions but</span>
+<span class="sd">                `trust = False`.</span>
+<span class="sd">        """</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="n">source</span> <span class="o">=</span> <span class="n">ModelConfig</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">source</span><span class="p">)</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span>
+            <span class="n">source</span><span class="p">,</span> <span class="n">ModelConfig</span>
+        <span class="p">),</span> <span class="sa">f</span><span class="s2">"Argument `source` of type (</span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">source</span><span class="p">)</span><span class="si">}</span><span class="s2">) is not a `ModelConfig"</span>
+
+        <span class="k">return</span> <span class="n">source</span><span class="o">.</span><span class="n">_construct_model</span><span class="p">(</span><span class="n">trust</span><span class="p">,</span> <span class="n">load_modules</span><span class="p">)</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/standard_model.html b/_modules/graphnet/models/standard_model.html
new file mode 100644
index 000000000..cb57cb523
--- /dev/null
+++ b/_modules/graphnet/models/standard_model.html
@@ -0,0 +1,604 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.standard_model &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/standard_model" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.standard_model </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-standard-model--page-root">Source code for graphnet.models.standard_model</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Standard model class(es)."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">ModuleList</span>
+<span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Adam</span>
+<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.model</span> <span class="kn">import</span> <span class="n">Model</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.task</span> <span class="kn">import</span> <span class="n">Task</span>
+
+
+<div class="viewcode-block" id="StandardModel">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel">[docs]</a>
+<span class="k">class</span> <span class="nc">StandardModel</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Main class for standard models in graphnet.</span>
+
+<span class="sd">    This class chains together the different elements of a complete GNN-based</span>
+<span class="sd">    model (detector read-in, GNN architecture, and task-specific read-outs).</span>
+<span class="sd">    """</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span>
+        <span class="n">gnn</span><span class="p">:</span> <span class="n">GNN</span><span class="p">,</span>
+        <span class="n">tasks</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Task</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">Task</span><span class="p">]],</span>
+        <span class="n">optimizer_class</span><span class="p">:</span> <span class="nb">type</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">,</span>
+        <span class="n">optimizer_kwargs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">scheduler_class</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">type</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">scheduler_kwargs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">scheduler_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `StandardModel`."""</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="n">Task</span><span class="p">):</span>
+            <span class="n">tasks</span> <span class="o">=</span> <span class="p">[</span><span class="n">tasks</span><span class="p">]</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">task</span><span class="p">,</span> <span class="n">Task</span><span class="p">)</span> <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="n">tasks</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">graph_definition</span><span class="p">,</span> <span class="n">GraphDefinition</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">gnn</span><span class="p">,</span> <span class="n">GNN</span><span class="p">)</span>
+
+        <span class="c1"># Member variable(s)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span> <span class="o">=</span> <span class="n">graph_definition</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_gnn</span> <span class="o">=</span> <span class="n">gnn</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span> <span class="o">=</span> <span class="n">ModuleList</span><span class="p">(</span><span class="n">tasks</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_optimizer_class</span> <span class="o">=</span> <span class="n">optimizer_class</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_optimizer_kwargs</span> <span class="o">=</span> <span class="n">optimizer_kwargs</span> <span class="ow">or</span> <span class="nb">dict</span><span class="p">()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_class</span> <span class="o">=</span> <span class="n">scheduler_class</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_kwargs</span> <span class="o">=</span> <span class="n">scheduler_kwargs</span> <span class="ow">or</span> <span class="nb">dict</span><span class="p">()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_config</span> <span class="o">=</span> <span class="n">scheduler_config</span> <span class="ow">or</span> <span class="nb">dict</span><span class="p">()</span>
+
+        <span class="c1"># set dtype of GNN from graph_definition</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_gnn</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span><span class="o">.</span><span class="n">_dtype</span><span class="p">)</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">target_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return target label."""</span>
+        <span class="k">return</span> <span class="p">[</span><span class="n">label</span> <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span> <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">task</span><span class="o">.</span><span class="n">_target_labels</span><span class="p">]</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">prediction_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return prediction labels."""</span>
+        <span class="k">return</span> <span class="p">[</span>
+            <span class="n">label</span> <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span> <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">task</span><span class="o">.</span><span class="n">_prediction_labels</span>
+        <span class="p">]</span>
+
+<div class="viewcode-block" id="StandardModel.configure_optimizers">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.configure_optimizers">[docs]</a>
+    <span class="k">def</span> <span class="nf">configure_optimizers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Configure the model's optimizer(s)."""</span>
+        <span class="n">optimizer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_optimizer_class</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">_optimizer_kwargs</span>
+        <span class="p">)</span>
+        <span class="n">config</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">"optimizer"</span><span class="p">:</span> <span class="n">optimizer</span><span class="p">,</span>
+        <span class="p">}</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_class</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">scheduler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_class</span><span class="p">(</span>
+                <span class="n">optimizer</span><span class="p">,</span> <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_kwargs</span>
+            <span class="p">)</span>
+            <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
+                <span class="p">{</span>
+                    <span class="s2">"lr_scheduler"</span><span class="p">:</span> <span class="p">{</span>
+                        <span class="s2">"scheduler"</span><span class="p">:</span> <span class="n">scheduler</span><span class="p">,</span>
+                        <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_config</span><span class="p">,</span>
+                    <span class="p">},</span>
+                <span class="p">}</span>
+            <span class="p">)</span>
+        <span class="k">return</span> <span class="n">config</span></div>
+
+
+<div class="viewcode-block" id="StandardModel.forward">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.forward">[docs]</a>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Forward pass, chaining model components."""</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">Data</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_gnn</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
+        <span class="n">preds</span> <span class="o">=</span> <span class="p">[</span><span class="n">task</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span><span class="p">]</span>
+        <span class="k">return</span> <span class="n">preds</span></div>
+
+
+<div class="viewcode-block" id="StandardModel.shared_step">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.shared_step">[docs]</a>
+    <span class="k">def</span> <span class="nf">shared_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Perform shared step.</span>
+
+<span class="sd">        Applies the forward pass and the following loss calculation, shared</span>
+<span class="sd">        between the training and validation step.</span>
+<span class="sd">        """</span>
+        <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
+        <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">loss</span></div>
+
+
+<div class="viewcode-block" id="StandardModel.training_step">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.training_step">[docs]</a>
+    <span class="k">def</span> <span class="nf">training_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">train_batch</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Perform training step."""</span>
+        <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_step</span><span class="p">(</span><span class="n">train_batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span>
+            <span class="s2">"train_loss"</span><span class="p">,</span>
+            <span class="n">loss</span><span class="p">,</span>
+            <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_batch_size</span><span class="p">(</span><span class="n">train_batch</span><span class="p">),</span>
+            <span class="n">prog_bar</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+            <span class="n">on_epoch</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+            <span class="n">on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
+            <span class="n">sync_dist</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="n">loss</span></div>
+
+
+<div class="viewcode-block" id="StandardModel.validation_step">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.validation_step">[docs]</a>
+    <span class="k">def</span> <span class="nf">validation_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val_batch</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Perform validation step."""</span>
+        <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_step</span><span class="p">(</span><span class="n">val_batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span>
+            <span class="s2">"val_loss"</span><span class="p">,</span>
+            <span class="n">loss</span><span class="p">,</span>
+            <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_batch_size</span><span class="p">(</span><span class="n">val_batch</span><span class="p">),</span>
+            <span class="n">prog_bar</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+            <span class="n">on_epoch</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+            <span class="n">on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
+            <span class="n">sync_dist</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="n">loss</span></div>
+
+
+<div class="viewcode-block" id="StandardModel.compute_loss">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.compute_loss">[docs]</a>
+    <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">verbose</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Compute and sum losses across tasks."""</span>
+        <span class="n">losses</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="n">task</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span>
+            <span class="k">for</span> <span class="n">task</span><span class="p">,</span> <span class="n">pred</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span><span class="p">,</span> <span class="n">preds</span><span class="p">)</span>
+        <span class="p">]</span>
+        <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">losses</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
+        <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span>
+            <span class="n">loss</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">loss</span> <span class="ow">in</span> <span class="n">losses</span>
+        <span class="p">),</span> <span class="s2">"Please reduce loss for each task separately"</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">))</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_get_batch_size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">numel</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">batch</span><span class="p">))</span>
+
+<div class="viewcode-block" id="StandardModel.inference">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.inference">[docs]</a>
+    <span class="k">def</span> <span class="nf">inference</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Activate inference mode."""</span>
+        <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span><span class="p">:</span>
+            <span class="n">task</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span></div>
+
+
+<div class="viewcode-block" id="StandardModel.train">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.train">[docs]</a>
+    <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"Model"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Deactivate inference mode."""</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">mode</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">mode</span><span class="p">:</span>
+            <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span><span class="p">:</span>
+                <span class="n">task</span><span class="o">.</span><span class="n">train_eval</span><span class="p">()</span>
+        <span class="k">return</span> <span class="bp">self</span></div>
+
+
+<div class="viewcode-block" id="StandardModel.predict">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.predict">[docs]</a>
+    <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span>
+        <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return predictions for `dataloader`."""</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span>
+        <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
+            <span class="n">dataloader</span><span class="o">=</span><span class="n">dataloader</span><span class="p">,</span>
+            <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span>
+            <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span>
+        <span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="StandardModel.predict_as_dataframe">
+<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.predict_as_dataframe">[docs]</a>
+    <span class="k">def</span> <span class="nf">predict_as_dataframe</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span>
+        <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return predictions for `dataloader` as a DataFrame.</span>
+
+<span class="sd">        Include `additional_attributes` as additional columns in the output</span>
+<span class="sd">        DataFrame.</span>
+<span class="sd">        """</span>
+        <span class="k">if</span> <span class="n">prediction_columns</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">prediction_columns</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_labels</span>
+        <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">predict_as_dataframe</span><span class="p">(</span>
+            <span class="n">dataloader</span><span class="o">=</span><span class="n">dataloader</span><span class="p">,</span>
+            <span class="n">prediction_columns</span><span class="o">=</span><span class="n">prediction_columns</span><span class="p">,</span>
+            <span class="n">additional_attributes</span><span class="o">=</span><span class="n">additional_attributes</span><span class="p">,</span>
+            <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span>
+            <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span>
+        <span class="p">)</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/task/classification.html b/_modules/graphnet/models/task/classification.html
new file mode 100644
index 000000000..b3c6c9440
--- /dev/null
+++ b/_modules/graphnet/models/task/classification.html
@@ -0,0 +1,411 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.task.classification &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/task/classification" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.task.classification </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-task-classification--page-root">Source code for graphnet.models.task.classification</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Classification-specific `Model` class(es)."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models.task</span> <span class="kn">import</span> <span class="n">Task</span><span class="p">,</span> <span class="n">IdentityTask</span>
+
+
+<div class="viewcode-block" id="MulticlassClassificationTask">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.classification.html#graphnet.models.task.classification.MulticlassClassificationTask">[docs]</a>
+<span class="k">class</span> <span class="nc">MulticlassClassificationTask</span><span class="p">(</span><span class="n">IdentityTask</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""General task for classifying any number of classes.</span>
+
+<span class="sd">    Requires the same number of input features as the number of classes being</span>
+<span class="sd">    predicted. Returns the untransformed latent features, which are interpreted</span>
+<span class="sd">    as the logits for each class being classified.</span>
+<span class="sd">    """</span></div>
+
+
+
+<div class="viewcode-block" id="BinaryClassificationTask">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask">[docs]</a>
+<span class="k">class</span> <span class="nc">BinaryClassificationTask</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Performs binary classification."""</span>
+
+    <span class="c1"># Requires one feature, logit for being signal class.</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"target"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"target_pred"</span><span class="p">]</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># transform probability of being muon</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="BinaryClassificationTaskLogits">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits">[docs]</a>
+<span class="k">class</span> <span class="nc">BinaryClassificationTaskLogits</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Performs binary classification form logits."""</span>
+
+    <span class="c1"># Requires one feature, logit for being signal class.</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"target"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"target_pred"</span><span class="p">]</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">x</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/task/reconstruction.html b/_modules/graphnet/models/task/reconstruction.html
new file mode 100644
index 000000000..ccb644a97
--- /dev/null
+++ b/_modules/graphnet/models/task/reconstruction.html
@@ -0,0 +1,609 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.task.reconstruction &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/task/reconstruction" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.task.reconstruction </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-task-reconstruction--page-root">Source code for graphnet.models.task.reconstruction</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Reconstruction-specific `Model` class(es)."""</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models.task</span> <span class="kn">import</span> <span class="n">Task</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.maths</span> <span class="kn">import</span> <span class="n">eps_like</span>
+
+
+<div class="viewcode-block" id="AzimuthReconstructionWithKappa">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa">[docs]</a>
+<span class="k">class</span> <span class="nc">AzimuthReconstructionWithKappa</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs azimuthal angle and associated kappa (1/var)."""</span>
+
+    <span class="c1"># Requires two features: untransformed points in (x,y)-space.</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"azimuth"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"azimuth_pred"</span><span class="p">,</span> <span class="s2">"azimuth_kappa"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">2</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform outputs to angle and prepare prediction</span>
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">vector_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">angle</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">atan2</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
+        <span class="n">angle</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span>
+            <span class="n">angle</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">,</span> <span class="n">angle</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span> <span class="n">angle</span>
+        <span class="p">)</span>  <span class="c1"># atan(y,x) -&gt; [-pi, pi]</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">angle</span><span class="p">,</span> <span class="n">kappa</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="AzimuthReconstruction">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction">[docs]</a>
+<span class="k">class</span> <span class="nc">AzimuthReconstruction</span><span class="p">(</span><span class="n">AzimuthReconstructionWithKappa</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs azimuthal angle."""</span>
+
+    <span class="c1"># Requires two features: untransformed points in (x,y)-space.</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"azimuth"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"azimuth_pred"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">2</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform outputs to angle and prepare prediction</span>
+        <span class="n">res</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">angle</span> <span class="o">=</span> <span class="n">res</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">res</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>
+        <span class="n">sigma</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">kappa</span><span class="p">)</span>
+        <span class="n">beta</span> <span class="o">=</span> <span class="mf">1e-3</span>
+        <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">sigma</span><span class="o">**</span><span class="mi">2</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sigma</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_regularisation_loss</span> <span class="o">+=</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">kl_loss</span>
+        <span class="k">return</span> <span class="n">angle</span></div>
+
+
+
+<div class="viewcode-block" id="DirectionReconstructionWithKappa">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa">[docs]</a>
+<span class="k">class</span> <span class="nc">DirectionReconstructionWithKappa</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs direction with kappa from the 3D-vMF distribution."""</span>
+
+    <span class="c1"># Requires three features: untransformed points in (x,y,z)-space.</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span>
+        <span class="s2">"direction"</span>
+    <span class="p">]</span>  <span class="c1"># contains dir_x, dir_y, dir_z see https://github.com/graphnet-team/graphnet/blob/95309556cfd46a4046bc4bd7609888aab649e295/src/graphnet/training/labels.py#L29</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span>
+        <span class="s2">"dir_x_pred"</span><span class="p">,</span>
+        <span class="s2">"dir_y_pred"</span><span class="p">,</span>
+        <span class="s2">"dir_z_pred"</span><span class="p">,</span>
+        <span class="s2">"direction_kappa"</span><span class="p">,</span>
+    <span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">3</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform outputs to angle and prepare prediction</span>
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">vector_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">vec_x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">kappa</span>
+        <span class="n">vec_y</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">kappa</span>
+        <span class="n">vec_z</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="n">kappa</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">vec_x</span><span class="p">,</span> <span class="n">vec_y</span><span class="p">,</span> <span class="n">vec_z</span><span class="p">,</span> <span class="n">kappa</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="ZenithReconstruction">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction">[docs]</a>
+<span class="k">class</span> <span class="nc">ZenithReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs zenith angle."""</span>
+
+    <span class="c1"># Requires two features: zenith angle itself.</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"zenith"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"zenith_pred"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform outputs to angle and prepare prediction</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">1</span><span class="p">])</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span></div>
+
+
+
+<div class="viewcode-block" id="ZenithReconstructionWithKappa">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa">[docs]</a>
+<span class="k">class</span> <span class="nc">ZenithReconstructionWithKappa</span><span class="p">(</span><span class="n">ZenithReconstruction</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs zenith angle and associated kappa (1/var)."""</span>
+
+    <span class="c1"># Requires one feature in addition to `ZenithReconstruction`: kappa (unceratinty; 1/variance).</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"zenith"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"zenith_pred"</span><span class="p">,</span> <span class="s2">"zenith_kappa"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">2</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform outputs to angle and prepare prediction</span>
+        <span class="n">angle</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span> <span class="o">+</span> <span class="n">eps_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">angle</span><span class="p">,</span> <span class="n">kappa</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="EnergyReconstruction">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction">[docs]</a>
+<span class="k">class</span> <span class="nc">EnergyReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs energy using stable method."""</span>
+
+    <span class="c1"># Requires one feature: untransformed energy</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy_pred"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform to positive energy domain avoiding `-inf` in `log10`</span>
+        <span class="c1"># Transform, thereby preventing overflow and underflow error.</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.05</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="EnergyReconstructionWithPower">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower">[docs]</a>
+<span class="k">class</span> <span class="nc">EnergyReconstructionWithPower</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs energy."""</span>
+
+    <span class="c1"># Requires one feature: untransformed energy</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy_pred"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform energy</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="EnergyReconstructionWithUncertainty">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty">[docs]</a>
+<span class="k">class</span> <span class="nc">EnergyReconstructionWithUncertainty</span><span class="p">(</span><span class="n">EnergyReconstruction</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs energy and associated uncertainty (log(var))."""</span>
+
+    <span class="c1"># Requires one feature in addition to `EnergyReconstruction`: log-variance (uncertainty).</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy_pred"</span><span class="p">,</span> <span class="s2">"energy_sigma"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">2</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform energy</span>
+        <span class="n">energy</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">log_var</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>
+        <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">energy</span><span class="p">,</span> <span class="n">log_var</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">pred</span></div>
+
+
+
+<div class="viewcode-block" id="VertexReconstruction">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction">[docs]</a>
+<span class="k">class</span> <span class="nc">VertexReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs vertex position and time."""</span>
+
+    <span class="c1"># Requires four features, x, y, z, and t.</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"vertex"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span>
+        <span class="s2">"position_x_pred"</span><span class="p">,</span>
+        <span class="s2">"position_y_pred"</span><span class="p">,</span>
+        <span class="s2">"position_z_pred"</span><span class="p">,</span>
+        <span class="s2">"interaction_time_pred"</span><span class="p">,</span>
+    <span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">4</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Scale xyz to roughly the right order of magnitude, leave time</span>
+        <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span>
+        <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span>
+        <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span>
+
+        <span class="k">return</span> <span class="n">x</span></div>
+
+
+
+<div class="viewcode-block" id="PositionReconstruction">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction">[docs]</a>
+<span class="k">class</span> <span class="nc">PositionReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs vertex position."""</span>
+
+    <span class="c1"># Requires three features, x, y, and z.</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"position"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span>
+        <span class="s2">"position_x_pred"</span><span class="p">,</span>
+        <span class="s2">"position_y_pred"</span><span class="p">,</span>
+        <span class="s2">"position_z_pred"</span><span class="p">,</span>
+    <span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">3</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Scale to roughly the right order of magnitude</span>
+        <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span>
+        <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span>
+        <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span>
+
+        <span class="k">return</span> <span class="n">x</span></div>
+
+
+
+<div class="viewcode-block" id="TimeReconstruction">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction">[docs]</a>
+<span class="k">class</span> <span class="nc">TimeReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs time."""</span>
+
+    <span class="c1"># Requires one feature, time.</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"interaction_time"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"interaction_time_pred"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Leave as it is</span>
+        <span class="k">return</span> <span class="n">x</span></div>
+
+
+
+<div class="viewcode-block" id="InelasticityReconstruction">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction">[docs]</a>
+<span class="k">class</span> <span class="nc">InelasticityReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Reconstructs interaction inelasticity.</span>
+
+<span class="sd">    That is, 1-(track energy / hadronic energy).</span>
+<span class="sd">    """</span>
+
+    <span class="c1"># Requires one features: inelasticity itself</span>
+    <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"inelasticity"</span><span class="p">]</span>
+    <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"inelasticity_pred"</span><span class="p">]</span>
+    <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Transform output to unit range</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/task/task.html b/_modules/graphnet/models/task/task.html
new file mode 100644
index 000000000..c43cb2f73
--- /dev/null
+++ b/_modules/graphnet/models/task/task.html
@@ -0,0 +1,688 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.task.task &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/task/task" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.task.task </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-task-task--page-root">Source code for graphnet.models.task.task</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Base physics task-specific `Model` class(es)."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">TYPE_CHECKING</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Optional</span>
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">Linear</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+
+<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
+    <span class="c1"># Avoid cyclic dependency</span>
+    <span class="kn">from</span> <span class="nn">graphnet.training.loss_functions</span> <span class="kn">import</span> <span class="n">LossFunction</span>  <span class="c1"># type: ignore[attr-defined]</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span>
+
+
+<div class="viewcode-block" id="Task">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task">[docs]</a>
+<span class="k">class</span> <span class="nc">Task</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base class for all reconstruction and classification tasks."""</span>
+
+    <span class="nd">@property</span>
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">nb_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return number of inputs assumed by task."""</span>
+
+    <span class="nd">@property</span>
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">default_target_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return default target labels."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_default_target_labels</span>
+
+    <span class="nd">@property</span>
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">default_prediction_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return default prediction labels."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_default_prediction_labels</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="o">*</span><span class="p">,</span>
+        <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">loss_function</span><span class="p">:</span> <span class="s2">"LossFunction"</span><span class="p">,</span>
+        <span class="n">target_labels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">prediction_labels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">transform_prediction_and_target</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">transform_target</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">transform_inference</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">transform_support</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `Task`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            hidden_size: The number of nodes in the layer feeding into this</span>
+<span class="sd">                tasks, used to construct the affine transformation to the</span>
+<span class="sd">                predicted quantity.</span>
+<span class="sd">            loss_function: Loss function appropriate to the task.</span>
+<span class="sd">            target_labels: Name(s) of the quantity/-ies being predicted, used</span>
+<span class="sd">                to extract the  target tensor(s) from the `Data` object in</span>
+<span class="sd">                `.compute_loss(...)`.</span>
+<span class="sd">            prediction_labels: The name(s) of each column that is predicted by</span>
+<span class="sd">                the model during inference. If not given, the name will auto</span>
+<span class="sd">                matically be set to `target_label + _pred`.</span>
+<span class="sd">            transform_prediction_and_target: Optional function to transform</span>
+<span class="sd">                both the predicted and target tensor before passing them to the</span>
+<span class="sd">                loss function. Useful e.g. for having the model predict</span>
+<span class="sd">                quantities on a physical scale, but transforming this scale to</span>
+<span class="sd">                O(1) for a numerically stable loss computation.</span>
+<span class="sd">            transform_target: Optional function to transform only the target</span>
+<span class="sd">                tensor before passing it, and the predicted tensor, to the loss</span>
+<span class="sd">                function. Useful e.g. for having the model predict a</span>
+<span class="sd">                transformed version of the target quantity, e.g. the log10-</span>
+<span class="sd">                scaled energy, rather than the physical quantity itself. Used</span>
+<span class="sd">                in conjunction with `transform_inference` to perform the</span>
+<span class="sd">                inverse transform on the predicted quantity to recover the</span>
+<span class="sd">                physical scale.</span>
+<span class="sd">            transform_inference: Optional function to inverse-transform the</span>
+<span class="sd">                model prediction to recover a physical scale. Used in</span>
+<span class="sd">                conjunction with `transform_target`.</span>
+<span class="sd">            transform_support: Optional tuple to specify minimum and maximum</span>
+<span class="sd">                of the range of validity for the inverse transforms</span>
+<span class="sd">                `transform_target` and `transform_inference` in case this is</span>
+<span class="sd">                restricted. By default the invertibility of `transform_target`</span>
+<span class="sd">                is tested on the range [-1e6, 1e6].</span>
+<span class="sd">            loss_weight: Name of the attribute in `data` containing per-event</span>
+<span class="sd">                loss weights.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="n">target_labels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">target_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_target_labels</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">target_labels</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="n">target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">target_labels</span><span class="p">]</span>
+
+        <span class="k">if</span> <span class="n">prediction_labels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">prediction_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_prediction_labels</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">prediction_labels</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+            <span class="n">prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">prediction_labels</span><span class="p">]</span>
+
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">target_labels</span><span class="p">,</span> <span class="n">List</span><span class="p">)</span>  <span class="c1"># mypy</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">prediction_labels</span><span class="p">,</span> <span class="n">List</span><span class="p">)</span>  <span class="c1"># mypy</span>
+        <span class="c1"># Member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_regularisation_loss</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_target_labels</span> <span class="o">=</span> <span class="n">target_labels</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_prediction_labels</span> <span class="o">=</span> <span class="n">prediction_labels</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_loss_function</span> <span class="o">=</span> <span class="n">loss_function</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span> <span class="o">=</span> <span class="kc">False</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight</span> <span class="o">=</span> <span class="n">loss_weight</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_training</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[</span>
+            <span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span>
+        <span class="p">]</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_inference</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[</span>
+            <span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span>
+        <span class="p">]</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_transform_target</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_validate_and_set_transforms</span><span class="p">(</span>
+            <span class="n">transform_prediction_and_target</span><span class="p">,</span>
+            <span class="n">transform_target</span><span class="p">,</span>
+            <span class="n">transform_inference</span><span class="p">,</span>
+            <span class="n">transform_support</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Mapping from last hidden layer to required size of input</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_affine</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span><span class="p">)</span>
+
+<div class="viewcode-block" id="Task.forward">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task.forward">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Forward pass."""</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_regularisation_loss</span> <span class="o">=</span> <span class="mi">0</span>  <span class="c1"># Reset</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_affine</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div>
+
+
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">_transform_prediction</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]:</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span><span class="p">:</span>
+            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_inference</span><span class="p">(</span><span class="n">prediction</span><span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_training</span><span class="p">(</span><span class="n">prediction</span><span class="p">)</span>
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Syntax like `.forward`, for implentation in inheriting classes."""</span>
+
+<div class="viewcode-block" id="Task.compute_loss">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task.compute_loss">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">],</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Compute loss of `pred` wrt.</span>
+
+<span class="sd">        target labels in `data`.</span>
+<span class="sd">        """</span>
+        <span class="n">target</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span>
+            <span class="p">[</span><span class="n">data</span><span class="p">[</span><span class="n">label</span><span class="p">]</span> <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_target_labels</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span>
+        <span class="p">)</span>
+        <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform_target</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">weights</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight</span><span class="p">]</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">weights</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_loss_function</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">)</span>
+            <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">_regularisation_loss</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="n">loss</span></div>
+
+
+<div class="viewcode-block" id="Task.inference">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task.inference">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">inference</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Activate inference mode."""</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span> <span class="o">=</span> <span class="kc">True</span></div>
+
+
+<div class="viewcode-block" id="Task.train_eval">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task.train_eval">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">train_eval</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Deactivate inference mode."""</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span> <span class="o">=</span> <span class="kc">False</span></div>
+
+
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">_validate_and_set_transforms</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">transform_prediction_and_target</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Callable</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
+        <span class="n">transform_target</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Callable</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
+        <span class="n">transform_inference</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Callable</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
+        <span class="n">transform_support</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Validate and set transforms.</span>
+
+<span class="sd">        Assert that a valid combination of transformation arguments are passed</span>
+<span class="sd">        and update the corresponding functions.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Checks</span>
+        <span class="k">assert</span> <span class="ow">not</span> <span class="p">(</span>
+            <span class="p">(</span><span class="n">transform_prediction_and_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
+            <span class="ow">and</span> <span class="p">(</span><span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
+        <span class="p">),</span> <span class="s2">"Please specify at most one of `transform_prediction_and_target` and `transform_target`"</span>
+        <span class="k">if</span> <span class="p">(</span><span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">!=</span> <span class="p">(</span><span class="n">transform_inference</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                <span class="s2">"Setting one of `transform_target` and `transform_inference`, but not "</span>
+                <span class="s2">"the other."</span>
+            <span class="p">)</span>
+
+        <span class="k">if</span> <span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+            <span class="k">assert</span> <span class="n">transform_inference</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+
+            <span class="k">if</span> <span class="n">transform_support</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="k">assert</span> <span class="n">transform_support</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+
+                <span class="k">assert</span> <span class="p">(</span>
+                    <span class="nb">len</span><span class="p">(</span><span class="n">transform_support</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
+                <span class="p">),</span> <span class="s2">"Please specify min and max for transformation support."</span>
+                <span class="n">x_test</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span>
+                    <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">transform_support</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">transform_support</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">10</span><span class="p">)</span>
+                <span class="p">)</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="n">x_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logspace</span><span class="p">(</span><span class="o">-</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">12</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
+                <span class="n">x_test</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span>
+                    <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="o">-</span><span class="n">x_test</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">x_test</span><span class="p">])</span>
+                <span class="p">)</span>
+
+            <span class="c1"># Add feature dimension before inference transformation to make it</span>
+            <span class="c1"># match the dimensions of a standard prediction. Remove it again</span>
+            <span class="c1"># before comparison. Temporary</span>
+            <span class="k">try</span><span class="p">:</span>
+                <span class="n">t_test</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">transform_target</span><span class="p">(</span><span class="n">x_test</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
+                <span class="n">t_test</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">transform_inference</span><span class="p">(</span><span class="n">t_test</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
+                <span class="n">valid</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">isfinite</span><span class="p">(</span><span class="n">t_test</span><span class="p">)</span>
+
+                <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">t_test</span><span class="p">[</span><span class="n">valid</span><span class="p">],</span> <span class="n">x_test</span><span class="p">[</span><span class="n">valid</span><span class="p">]),</span> <span class="p">(</span>
+                    <span class="s2">"The provided transforms for targets during training and "</span>
+                    <span class="s2">"predictions during inference are not inverse. Please "</span>
+                    <span class="s2">"adjust transformation functions or support."</span>
+                <span class="p">)</span>
+                <span class="k">del</span> <span class="n">x_test</span><span class="p">,</span> <span class="n">t_test</span><span class="p">,</span> <span class="n">valid</span>
+
+            <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                    <span class="s2">"transform_target and/or transform_inference rely on "</span>
+                    <span class="s2">"indexing, which we won't validate. Please make sure that "</span>
+                    <span class="s2">"they are mutually inverse, i.e. that</span><span class="se">\n</span><span class="s2">"</span>
+                    <span class="s2">"  x = transform_inference(transform_target(x))</span><span class="se">\n</span><span class="s2">"</span>
+                    <span class="s2">"for all x that are within your target range."</span>
+                <span class="p">)</span>
+
+        <span class="c1"># Set transforms</span>
+        <span class="k">if</span> <span class="n">transform_prediction_and_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_training</span> <span class="o">=</span> <span class="p">(</span>
+                <span class="n">transform_prediction_and_target</span>
+            <span class="p">)</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_transform_target</span> <span class="o">=</span> <span class="n">transform_prediction_and_target</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">if</span> <span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_transform_target</span> <span class="o">=</span> <span class="n">transform_target</span>
+            <span class="k">if</span> <span class="n">transform_inference</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_inference</span> <span class="o">=</span> <span class="n">transform_inference</span></div>
+
+
+
+<div class="viewcode-block" id="IdentityTask">
+<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask">[docs]</a>
+<span class="k">class</span> <span class="nc">IdentityTask</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Identity, or trivial, task."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">nb_outputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">target_labels</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">Any</span><span class="p">],</span>
+        <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct IdentityTask.</span>
+
+<span class="sd">        Return the `nb_outputs` as a direct, affine transformation of the last</span>
+<span class="sd">        hidden layer.</span>
+<span class="sd">        """</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> <span class="o">=</span> <span class="n">nb_outputs</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_default_target_labels</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="n">target_labels</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">target_labels</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+            <span class="k">else</span> <span class="p">[</span><span class="n">target_labels</span><span class="p">]</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="sa">f</span><span class="s2">"target_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">_pred"</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_default_target_labels</span><span class="p">))</span>
+        <span class="p">]</span>
+
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+        <span class="c1"># Base class constructor</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">default_target_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return default target labels."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_default_target_labels</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">default_prediction_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Return default prediction labels."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_default_prediction_labels</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">nb_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return number of inputs assumed by task."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="c1"># Leave it as is.</span>
+        <span class="k">return</span> <span class="n">x</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/models/utils.html b/_modules/graphnet/models/utils.html
new file mode 100644
index 000000000..35c1b3460
--- /dev/null
+++ b/_modules/graphnet/models/utils.html
@@ -0,0 +1,430 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.models.utils &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/models/utils" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.models.utils </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-models-utils--page-root">Source code for graphnet.models.utils</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Utility functions for `graphnet.models`."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.nn</span> <span class="kn">import</span> <span class="n">knn_graph</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Batch</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span>
+
+<span class="kn">from</span> <span class="nn">torch_geometric.utils.homophily</span> <span class="kn">import</span> <span class="n">homophily</span>
+
+
+<div class="viewcode-block" id="calculate_xyzt_homophily">
+<a class="viewcode-back" href="../../../api/graphnet.models.utils.html#graphnet.models.utils.calculate_xyzt_homophily">[docs]</a>
+<span class="k">def</span> <span class="nf">calculate_xyzt_homophily</span><span class="p">(</span>
+    <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Batch</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
+<span class="w">    </span><span class="sd">"""Calculate xyzt-homophily from a batch of graphs.</span>
+
+<span class="sd">    Homophily is a graph scalar quantity that measures the likeness of</span>
+<span class="sd">    variables in nodes. Notice that this calculator assumes a special order of</span>
+<span class="sd">    input features in x.</span>
+
+<span class="sd">    Returns:</span>
+<span class="sd">        Tuple, each element with shape [batch_size,1].</span>
+<span class="sd">    """</span>
+    <span class="n">hx</span> <span class="o">=</span> <span class="n">homophily</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+    <span class="n">hy</span> <span class="o">=</span> <span class="n">homophily</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+    <span class="n">hz</span> <span class="o">=</span> <span class="n">homophily</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+    <span class="n">ht</span> <span class="o">=</span> <span class="n">homophily</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+    <span class="k">return</span> <span class="n">hx</span><span class="p">,</span> <span class="n">hy</span><span class="p">,</span> <span class="n">hz</span><span class="p">,</span> <span class="n">ht</span></div>
+
+
+
+<div class="viewcode-block" id="calculate_distance_matrix">
+<a class="viewcode-back" href="../../../api/graphnet.models.utils.html#graphnet.models.utils.calculate_distance_matrix">[docs]</a>
+<span class="k">def</span> <span class="nf">calculate_distance_matrix</span><span class="p">(</span><span class="n">xyz_coords</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Calculate the matrix of pairwise distances between pulses.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        xyz_coords: (x,y,z)-coordinates of pulses, of shape [nb_doms, 3].</span>
+
+<span class="sd">    Returns:</span>
+<span class="sd">        Matrix of pairwise distances, of shape [nb_doms, nb_doms]</span>
+<span class="sd">    """</span>
+    <span class="n">diff</span> <span class="o">=</span> <span class="n">xyz_coords</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">xyz_coords</span><span class="o">.</span><span class="n">T</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">diff</span><span class="o">**</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span></div>
+
+
+
+<div class="viewcode-block" id="knn_graph_batch">
+<a class="viewcode-back" href="../../../api/graphnet.models.utils.html#graphnet.models.utils.knn_graph_batch">[docs]</a>
+<span class="k">def</span> <span class="nf">knn_graph_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">:</span> <span class="n">Batch</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Batch</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Calculate k-nearest-neighbours with individual k for each batch event.</span>
+
+<span class="sd">    Args:</span>
+<span class="sd">        batch: Batch of events.</span>
+<span class="sd">        k: A list of k's.</span>
+<span class="sd">        columns: The columns of Data.x used for computing the distances. E.g.,</span>
+<span class="sd">            Data.x[:,[0,1,2]]</span>
+
+<span class="sd">    Returns:</span>
+<span class="sd">        Returns the same batch of events, but with updated edges.</span>
+<span class="sd">    """</span>
+    <span class="n">data_list</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">to_data_list</span><span class="p">()</span>
+    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data_list</span><span class="p">)):</span>
+        <span class="n">data_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">knn_graph</span><span class="p">(</span>
+            <span class="n">x</span><span class="o">=</span><span class="n">data_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">columns</span><span class="p">],</span> <span class="n">k</span><span class="o">=</span><span class="n">k</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
+        <span class="p">)</span>
+    <span class="k">return</span> <span class="n">Batch</span><span class="o">.</span><span class="n">from_data_list</span><span class="p">(</span><span class="n">data_list</span><span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/pisa/fitting.html b/_modules/graphnet/pisa/fitting.html
index 8be91cc6e..1865353e4 100644
--- a/_modules/graphnet/pisa/fitting.html
+++ b/_modules/graphnet/pisa/fitting.html
@@ -1169,7 +1169,7 @@ <h1 id="modules-graphnet-pisa-fitting--page-root">Source code for graphnet.pisa.
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/pisa/plotting.html b/_modules/graphnet/pisa/plotting.html
index c1cbdd4a9..6f31c1d90 100644
--- a/_modules/graphnet/pisa/plotting.html
+++ b/_modules/graphnet/pisa/plotting.html
@@ -528,7 +528,7 @@ <h1 id="modules-graphnet-pisa-plotting--page-root">Source code for graphnet.pisa
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/training/callbacks.html b/_modules/graphnet/training/callbacks.html
new file mode 100644
index 000000000..fcf981d92
--- /dev/null
+++ b/_modules/graphnet/training/callbacks.html
@@ -0,0 +1,544 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.training.callbacks &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/training/callbacks" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.training.callbacks </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-training-callbacks--page-root">Source code for graphnet.training.callbacks</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Callback class(es) for using during model training."""</span>
+
+<span class="kn">import</span> <span class="nn">logging</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span>
+<span class="kn">import</span> <span class="nn">warnings</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">from</span> <span class="nn">tqdm.std</span> <span class="kn">import</span> <span class="n">Bar</span>
+
+<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">Trainer</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning.callbacks</span> <span class="kn">import</span> <span class="n">TQDMProgressBar</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning.utilities</span> <span class="kn">import</span> <span class="n">rank_zero_only</span>
+<span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Optimizer</span>
+<span class="kn">from</span> <span class="nn">torch.optim.lr_scheduler</span> <span class="kn">import</span> <span class="n">_LRScheduler</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span>
+
+
+<div class="viewcode-block" id="PiecewiseLinearLR">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR">[docs]</a>
+<span class="k">class</span> <span class="nc">PiecewiseLinearLR</span><span class="p">(</span><span class="n">_LRScheduler</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Interpolate learning rate linearly between milestones."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">optimizer</span><span class="p">:</span> <span class="n">Optimizer</span><span class="p">,</span>
+        <span class="n">milestones</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
+        <span class="n">factors</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span>
+        <span class="n">last_epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
+        <span class="n">verbose</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `PiecewiseLinearLR`.</span>
+
+<span class="sd">        For each milestone, denoting a specified number of steps, a factor</span>
+<span class="sd">        multiplying the base learning rate is specified. For steps between two</span>
+<span class="sd">        milestones, the learning rate is interpolated linearly between the two</span>
+<span class="sd">        closest milestones. For steps before the first milestone, the factor</span>
+<span class="sd">        for the first milestone is used; vice versa for steps after the last</span>
+<span class="sd">        milestone.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            optimizer: Wrapped optimizer.</span>
+<span class="sd">            milestones: List of step indices. Must be increasing.</span>
+<span class="sd">            factors: List of multiplicative factors. Must be same length as</span>
+<span class="sd">                `milestones`.</span>
+<span class="sd">            last_epoch: The index of the last epoch.</span>
+<span class="sd">            verbose: If ``True``, prints a message to stdout for each update.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="n">milestones</span> <span class="o">!=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">milestones</span><span class="p">):</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Milestones must be increasing"</span><span class="p">)</span>
+        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">milestones</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">factors</span><span class="p">):</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                <span class="s2">"Only multiplicative factor must be specified for each milestone."</span>
+            <span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">milestones</span> <span class="o">=</span> <span class="n">milestones</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">factors</span> <span class="o">=</span> <span class="n">factors</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">last_epoch</span><span class="p">,</span> <span class="n">verbose</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_get_factor</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
+        <span class="c1"># Linearly interpolate multiplicative factor between milestones.</span>
+        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">interp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">last_epoch</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">milestones</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">factors</span><span class="p">)</span>
+
+<div class="viewcode-block" id="PiecewiseLinearLR.get_lr">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR.get_lr">[docs]</a>
+    <span class="k">def</span> <span class="nf">get_lr</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Get effective learning rate(s) for each optimizer."""</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_lr_called_within_step</span><span class="p">:</span>
+            <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span>
+                <span class="s2">"To get the last learning rate computed by the scheduler, "</span>
+                <span class="s2">"please use `get_last_lr()`."</span><span class="p">,</span>
+                <span class="ne">UserWarning</span><span class="p">,</span>
+            <span class="p">)</span>
+
+        <span class="k">return</span> <span class="p">[</span><span class="n">base_lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_factor</span><span class="p">()</span> <span class="k">for</span> <span class="n">base_lr</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_lrs</span><span class="p">]</span></div>
+</div>
+
+
+
+<div class="viewcode-block" id="ProgressBar">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar">[docs]</a>
+<span class="k">class</span> <span class="nc">ProgressBar</span><span class="p">(</span><span class="n">TQDMProgressBar</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Custom progress bar for graphnet.</span>
+
+<span class="sd">    Customises the default progress in pytorch-lightning.</span>
+<span class="sd">    """</span>
+
+    <span class="k">def</span> <span class="nf">_common_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">bar</span><span class="p">:</span> <span class="n">Bar</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Bar</span><span class="p">:</span>
+        <span class="n">bar</span><span class="o">.</span><span class="n">unit</span> <span class="o">=</span> <span class="s2">" batch(es)"</span>
+        <span class="n">bar</span><span class="o">.</span><span class="n">colour</span> <span class="o">=</span> <span class="s2">"green"</span>
+        <span class="k">return</span> <span class="n">bar</span>
+
+<div class="viewcode-block" id="ProgressBar.init_validation_tqdm">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_validation_tqdm">[docs]</a>
+    <span class="k">def</span> <span class="nf">init_validation_tqdm</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Bar</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Override for customisation."""</span>
+        <span class="n">bar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_validation_tqdm</span><span class="p">()</span>
+        <span class="n">bar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_common_config</span><span class="p">(</span><span class="n">bar</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">bar</span></div>
+
+
+<div class="viewcode-block" id="ProgressBar.init_predict_tqdm">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_predict_tqdm">[docs]</a>
+    <span class="k">def</span> <span class="nf">init_predict_tqdm</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Bar</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Override for customisation."""</span>
+        <span class="n">bar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_predict_tqdm</span><span class="p">()</span>
+        <span class="n">bar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_common_config</span><span class="p">(</span><span class="n">bar</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">bar</span></div>
+
+
+<div class="viewcode-block" id="ProgressBar.init_test_tqdm">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_test_tqdm">[docs]</a>
+    <span class="k">def</span> <span class="nf">init_test_tqdm</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Bar</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Override for customisation."""</span>
+        <span class="n">bar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_test_tqdm</span><span class="p">()</span>
+        <span class="n">bar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_common_config</span><span class="p">(</span><span class="n">bar</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">bar</span></div>
+
+
+<div class="viewcode-block" id="ProgressBar.init_train_tqdm">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_train_tqdm">[docs]</a>
+    <span class="k">def</span> <span class="nf">init_train_tqdm</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Bar</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Override for customisation."""</span>
+        <span class="n">bar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_train_tqdm</span><span class="p">()</span>
+        <span class="n">bar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_common_config</span><span class="p">(</span><span class="n">bar</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">bar</span></div>
+
+
+<div class="viewcode-block" id="ProgressBar.get_metrics">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.get_metrics">[docs]</a>
+    <span class="k">def</span> <span class="nf">get_metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">LightningModule</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Override to not show the version number in the logging."""</span>
+        <span class="n">items</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_metrics</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span>
+        <span class="n">items</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"v_num"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">items</span></div>
+
+
+<div class="viewcode-block" id="ProgressBar.on_train_epoch_start">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.on_train_epoch_start">[docs]</a>
+    <span class="k">def</span> <span class="nf">on_train_epoch_start</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">LightningModule</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Print the results of the previous epoch on a separate line.</span>
+
+<span class="sd">        This allows the user to see the losses/metrics for previous epochs</span>
+<span class="sd">        while the current is training. The default behaviour in pytorch-</span>
+<span class="sd">        lightning is to overwrite the progress bar from previous epochs.</span>
+<span class="sd">        """</span>
+        <span class="k">if</span> <span class="n">trainer</span><span class="o">.</span><span class="n">current_epoch</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">train_progress_bar</span><span class="o">.</span><span class="n">set_postfix</span><span class="p">(</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">get_metrics</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span>
+            <span class="p">)</span>
+            <span class="nb">print</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">on_train_epoch_start</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">train_progress_bar</span><span class="o">.</span><span class="n">set_description</span><span class="p">(</span>
+            <span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="n">trainer</span><span class="o">.</span><span class="n">current_epoch</span><span class="si">:</span><span class="s2">2d</span><span class="si">}</span><span class="s2">"</span>
+        <span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="ProgressBar.on_train_epoch_end">
+<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.on_train_epoch_end">[docs]</a>
+    <span class="k">def</span> <span class="nf">on_train_epoch_end</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">LightningModule</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Log the final progress bar for the epoch to file.</span>
+
+<span class="sd">        Don't duplciate to stdout.</span>
+<span class="sd">        """</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">on_train_epoch_end</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span>
+
+        <span class="k">if</span> <span class="n">rank_zero_only</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
+            <span class="c1"># Construct Logger</span>
+            <span class="n">logger</span> <span class="o">=</span> <span class="n">Logger</span><span class="p">()</span>
+
+            <span class="c1"># Log only to file, not stream</span>
+            <span class="n">h</span> <span class="o">=</span> <span class="n">logger</span><span class="o">.</span><span class="n">handlers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">logging</span><span class="o">.</span><span class="n">StreamHandler</span><span class="p">)</span>
+            <span class="n">level</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">level</span>
+            <span class="n">h</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">logging</span><span class="o">.</span><span class="n">ERROR</span><span class="p">)</span>
+            <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">train_progress_bar</span><span class="p">))</span>
+            <span class="n">h</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/training/labels.html b/_modules/graphnet/training/labels.html
new file mode 100644
index 000000000..b2dd18a77
--- /dev/null
+++ b/_modules/graphnet/training/labels.html
@@ -0,0 +1,436 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.training.labels &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/training/labels" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.training.labels </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-training-labels--page-root">Source code for graphnet.training.labels</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Class(es) for constructing training labels at runtime."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span>
+
+
+<div class="viewcode-block" id="Label">
+<a class="viewcode-back" href="../../../api/graphnet.training.labels.html#graphnet.training.labels.Label">[docs]</a>
+<span class="k">class</span> <span class="nc">Label</span><span class="p">(</span><span class="n">ABC</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base `Label` class for producing labels from single `Data` instance."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `Label`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            key: The name of the field in `Data` where the label will be</span>
+<span class="sd">                stored. That is, `graph[key] = label`.</span>
+<span class="sd">        """</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_key</span> <span class="o">=</span> <span class="n">key</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">key</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return value of `key`."""</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_key</span>
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Label-specific implementation."""</span></div>
+
+
+
+<div class="viewcode-block" id="Direction">
+<a class="viewcode-back" href="../../../api/graphnet.training.labels.html#graphnet.training.labels.Direction">[docs]</a>
+<span class="k">class</span> <span class="nc">Direction</span><span class="p">(</span><span class="n">Label</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Class for producing particle direction/pointing label."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">key</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"direction"</span><span class="p">,</span>
+        <span class="n">azimuth_key</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"azimuth"</span><span class="p">,</span>
+        <span class="n">zenith_key</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"zenith"</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct `Direction`.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            key: The name of the field in `Data` where the label will be</span>
+<span class="sd">                stored. That is, `graph[key] = label`.</span>
+<span class="sd">            azimuth_key: The name of the pre-existing key in `graph` that will</span>
+<span class="sd">                be used to access the azimiuth angle, used when calculating</span>
+<span class="sd">                the direction.</span>
+<span class="sd">            zenith_key: The name of the pre-existing key in `graph` that will</span>
+<span class="sd">                be used to access the zenith angle, used when calculating the</span>
+<span class="sd">                direction.</span>
+<span class="sd">        """</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_azimuth_key</span> <span class="o">=</span> <span class="n">azimuth_key</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_zenith_key</span> <span class="o">=</span> <span class="n">zenith_key</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">key</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Compute label for `graph`."""</span>
+        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_azimuth_key</span><span class="p">])</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span>
+            <span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_zenith_key</span><span class="p">]</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_azimuth_key</span><span class="p">])</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span>
+            <span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_zenith_key</span><span class="p">]</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_zenith_key</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/training/loss_functions.html b/_modules/graphnet/training/loss_functions.html
new file mode 100644
index 000000000..dd70fd272
--- /dev/null
+++ b/_modules/graphnet/training/loss_functions.html
@@ -0,0 +1,859 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.training.loss_functions &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/training/loss_functions" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.training.loss_functions </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-training-loss-functions--page-root">Source code for graphnet.training.loss_functions</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Collection of loss functions.</span>
+
+<span class="sd">All loss functions inherit from `LossFunction` which ensures a common syntax,</span>
+<span class="sd">handles per-event weights, etc.</span>
+<span class="sd">"""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">scipy.special</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
+<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
+<span class="kn">from</span> <span class="nn">torch.nn.functional</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">one_hot</span><span class="p">,</span>
+    <span class="n">cross_entropy</span><span class="p">,</span>
+    <span class="n">binary_cross_entropy</span><span class="p">,</span>
+    <span class="n">softplus</span><span class="p">,</span>
+<span class="p">)</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.model</span> <span class="kn">import</span> <span class="n">Model</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span>
+
+
+<div class="viewcode-block" id="LossFunction">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction">[docs]</a>
+<span class="k">class</span> <span class="nc">LossFunction</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base class for loss functions in `graphnet`."""</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `LossFunction`, saving model config."""</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+<div class="viewcode-block" id="LossFunction.forward">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction.forward">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>  <span class="c1"># type: ignore[override]</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">weights</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">return_elements</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Forward pass for all loss functions.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            prediction: Tensor containing predictions. Shape [N,P]</span>
+<span class="sd">            target: Tensor containing targets. Shape [N,T]</span>
+<span class="sd">            return_elements: Whether elementwise loss terms should be returned.</span>
+<span class="sd">                The alternative is to return the averaged loss across examples.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            Loss, either averaged to a scalar (if `return_elements = False`) or</span>
+<span class="sd">            elementwise terms with shape [N,] (if `return_elements = True`).</span>
+<span class="sd">        """</span>
+        <span class="n">elements</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">elements</span> <span class="o">=</span> <span class="n">elements</span> <span class="o">*</span> <span class="n">weights</span>
+        <span class="k">assert</span> <span class="n">elements</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">(</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">0</span>
+        <span class="p">),</span> <span class="s2">"`_forward` should return elementwise loss terms."</span>
+
+        <span class="k">return</span> <span class="n">elements</span> <span class="k">if</span> <span class="n">return_elements</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">elements</span><span class="p">)</span></div>
+
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Syntax like `.forward`, for implentation in inheriting classes."""</span></div>
+
+
+
+<div class="viewcode-block" id="MSELoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.MSELoss">[docs]</a>
+<span class="k">class</span> <span class="nc">MSELoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Mean squared error loss."""</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Implement loss calculation."""</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span>
+        <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
+
+        <span class="n">elements</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">((</span><span class="n">prediction</span> <span class="o">-</span> <span class="n">target</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">elements</span></div>
+
+
+
+<div class="viewcode-block" id="RMSELoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.RMSELoss">[docs]</a>
+<span class="k">class</span> <span class="nc">RMSELoss</span><span class="p">(</span><span class="n">MSELoss</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Root mean squared error loss."""</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Implement loss calculation."""</span>
+        <span class="c1"># Check(s)</span>
+        <span class="n">elements</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
+        <span class="n">elements</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">elements</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">elements</span></div>
+
+
+
+<div class="viewcode-block" id="LogCoshLoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCoshLoss">[docs]</a>
+<span class="k">class</span> <span class="nc">LogCoshLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Log-cosh loss function.</span>
+
+<span class="sd">    Acts like x^2 for small x; and like |x| for large x.</span>
+<span class="sd">    """</span>
+
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">_log_cosh</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>  <span class="c1"># pylint: disable=invalid-name</span>
+<span class="w">        </span><span class="sd">"""Numerically stable version on log(cosh(x)).</span>
+
+<span class="sd">        Used to avoid `inf` for even moderately large differences.</span>
+<span class="sd">        See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617]</span>
+<span class="sd">        """</span>
+        <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">softplus</span><span class="p">(</span><span class="o">-</span><span class="mf">2.0</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">2.0</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Implement loss calculation."""</span>
+        <span class="n">diff</span> <span class="o">=</span> <span class="n">prediction</span> <span class="o">-</span> <span class="n">target</span>
+        <span class="n">elements</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_log_cosh</span><span class="p">(</span><span class="n">diff</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">elements</span></div>
+
+
+
+<div class="viewcode-block" id="CrossEntropyLoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.CrossEntropyLoss">[docs]</a>
+<span class="k">class</span> <span class="nc">CrossEntropyLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Compute cross-entropy loss for classification tasks.</span>
+
+<span class="sd">    Predictions are an [N, num_class]-matrix of logits (i.e., non-softmax'ed</span>
+<span class="sd">    probabilities), and targets are an [N,1]-matrix with integer values in</span>
+<span class="sd">    (0, num_classes - 1).</span>
+<span class="sd">    """</span>
+
+    <span class="nd">@save_model_config</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">options</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">Any</span><span class="p">],</span> <span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span>
+        <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span>
+    <span class="p">):</span>
+<span class="w">        </span><span class="sd">"""Construct CrossEntropyLoss."""</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+        <span class="c1"># Member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_options</span> <span class="o">=</span> <span class="n">options</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span><span class="p">:</span> <span class="nb">int</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
+            <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_options</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">]</span>
+            <span class="k">assert</span> <span class="p">(</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_options</span> <span class="o">&gt;=</span> <span class="mi">2</span>
+            <span class="p">),</span> <span class="sa">f</span><span class="s2">"Minimum of two classes required. Got </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="si">}</span><span class="s2">."</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span> <span class="o">=</span> <span class="n">options</span>  <span class="c1"># type: ignore</span>
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">)</span>  <span class="c1"># type: ignore</span>
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span>
+                <span class="n">np</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="o">.</span><span class="n">values</span><span class="p">()))</span>
+            <span class="p">)</span>  <span class="c1"># type: ignore</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                <span class="sa">f</span><span class="s2">"Class options of type </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">)</span><span class="si">}</span><span class="s2"> not supported"</span>
+            <span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">"none"</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Transform outputs to angle and prepare prediction."""</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
+            <span class="c1"># Integer number of classes: Targets are expected to be in</span>
+            <span class="c1"># (0, nb_classes - 1).</span>
+
+            <span class="c1"># Target integers are positive</span>
+            <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">target</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span>
+
+            <span class="c1"># Target integers are consistent with the expected number of class.</span>
+            <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">target</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">)</span>
+
+            <span class="k">assert</span> <span class="n">target</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">]</span>
+            <span class="n">target_integer</span> <span class="o">=</span> <span class="n">target</span>
+
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="c1"># List of classes: Mapping target classes in list onto</span>
+            <span class="c1"># (0, nb_classes - 1). Example:</span>
+            <span class="c1">#    Given options: [1, 12, 13, ...]</span>
+            <span class="c1">#    Yields: [1, 13, 12] -&gt; [0, 2, 1, ...]</span>
+            <span class="n">target_integer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
+                <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">target</span><span class="p">]</span>
+            <span class="p">)</span>
+
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
+            <span class="c1"># Dictionary of classes: Mapping target classes in dict onto</span>
+            <span class="c1"># (0, nb_classes - 1). Example:</span>
+            <span class="c1">#     Given options: {1: 0, -1: 0, 12: 1, -12: 1, ...}</span>
+            <span class="c1">#     Yields: [1, -1, -12, ...] -&gt; [0, 0, 1, ...]</span>
+            <span class="n">target_integer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
+                <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">value</span><span class="p">)]</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">target</span><span class="p">]</span>
+            <span class="p">)</span>
+
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">"Shouldn't reach here."</span>
+
+        <span class="n">target_one_hot</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="n">one_hot</span><span class="p">(</span><span class="n">target_integer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
+            <span class="n">prediction</span><span class="o">.</span><span class="n">device</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss</span><span class="p">(</span><span class="n">prediction</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">target_one_hot</span><span class="o">.</span><span class="n">float</span><span class="p">())</span></div>
+
+
+
+<div class="viewcode-block" id="BinaryCrossEntropyLoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.BinaryCrossEntropyLoss">[docs]</a>
+<span class="k">class</span> <span class="nc">BinaryCrossEntropyLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Compute binary cross entropy loss.</span>
+
+<span class="sd">    Predictions are vector probabilities (i.e., values between 0 and 1), and</span>
+<span class="sd">    targets should be 0 and 1.</span>
+<span class="sd">    """</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">binary_cross_entropy</span><span class="p">(</span>
+            <span class="n">prediction</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">target</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">"none"</span>
+        <span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="LogCMK">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK">[docs]</a>
+<span class="k">class</span> <span class="nc">LogCMK</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""MIT License.</span>
+
+<span class="sd">    Copyright (c) 2019 Max Ryabinin</span>
+
+<span class="sd">    Permission is hereby granted, free of charge, to any person obtaining a copy</span>
+<span class="sd">    of this software and associated documentation files (the "Software"), to deal</span>
+<span class="sd">    in the Software without restriction, including without limitation the rights</span>
+<span class="sd">    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell</span>
+<span class="sd">    copies of the Software, and to permit persons to whom the Software is</span>
+<span class="sd">    furnished to do so, subject to the following conditions:</span>
+
+<span class="sd">    The above copyright notice and this permission notice shall be included in all</span>
+<span class="sd">    copies or substantial portions of the Software.</span>
+
+<span class="sd">    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR</span>
+<span class="sd">    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,</span>
+<span class="sd">    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE</span>
+<span class="sd">    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER</span>
+<span class="sd">    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,</span>
+<span class="sd">    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE</span>
+<span class="sd">    SOFTWARE.</span>
+<span class="sd">    _____________________</span>
+
+<span class="sd">    From [https://github.com/mryab/vmf_loss/blob/master/losses.py]</span>
+<span class="sd">    Modified to use modified Bessel function instead of exponentially scaled ditto</span>
+<span class="sd">    (i.e. `.ive` -&gt; `.iv`) as indiciated in [1812.04616] in spite of suggestion in</span>
+<span class="sd">    Sec. 8.2 of this paper. The change has been validated through comparison with</span>
+<span class="sd">    exact calculations for `m=2` and `m=3` and found to yield the correct results.</span>
+<span class="sd">    """</span>
+
+<div class="viewcode-block" id="LogCMK.forward">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK.forward">[docs]</a>
+    <span class="nd">@staticmethod</span>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
+        <span class="n">ctx</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">m</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kappa</span><span class="p">:</span> <span class="n">Tensor</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>  <span class="c1"># pylint: disable=invalid-name,arguments-differ</span>
+<span class="w">        </span><span class="sd">"""Forward pass."""</span>
+        <span class="n">dtype</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">dtype</span>
+        <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">kappa</span><span class="p">)</span>
+        <span class="n">ctx</span><span class="o">.</span><span class="n">m</span> <span class="o">=</span> <span class="n">m</span>
+        <span class="n">ctx</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">double</span><span class="p">()</span>
+        <span class="n">iv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span>
+            <span class="n">scipy</span><span class="o">.</span><span class="n">special</span><span class="o">.</span><span class="n">iv</span><span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kappa</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">kappa</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
+        <span class="k">return</span> <span class="p">(</span>
+            <span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">kappa</span><span class="p">)</span>
+            <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">iv</span><span class="p">)</span>
+            <span class="o">-</span> <span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span>
+        <span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="LogCMK.backward">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK.backward">[docs]</a>
+    <span class="nd">@staticmethod</span>
+    <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span>
+        <span class="n">ctx</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">grad_output</span><span class="p">:</span> <span class="n">Tensor</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>  <span class="c1"># pylint: disable=invalid-name,arguments-differ</span>
+<span class="w">        </span><span class="sd">"""Backward pass."""</span>
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+        <span class="n">m</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">m</span>
+        <span class="n">dtype</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">dtype</span>
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">double</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
+        <span class="n">grads</span> <span class="o">=</span> <span class="o">-</span><span class="p">(</span>
+            <span class="p">(</span><span class="n">scipy</span><span class="o">.</span><span class="n">special</span><span class="o">.</span><span class="n">iv</span><span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">,</span> <span class="n">kappa</span><span class="p">))</span>
+            <span class="o">/</span> <span class="p">(</span><span class="n">scipy</span><span class="o">.</span><span class="n">special</span><span class="o">.</span><span class="n">iv</span><span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kappa</span><span class="p">))</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="p">(</span>
+            <span class="kc">None</span><span class="p">,</span>
+            <span class="n">grad_output</span>
+            <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">grads</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">grad_output</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
+        <span class="p">)</span></div>
+</div>
+
+
+
+<div class="viewcode-block" id="VonMisesFisherLoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss">[docs]</a>
+<span class="k">class</span> <span class="nc">VonMisesFisherLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""General class for calculating von Mises-Fisher loss.</span>
+
+<span class="sd">    Requires implementation for specific dimension `m` in which the target and</span>
+<span class="sd">    prediction vectors need to be prepared.</span>
+<span class="sd">    """</span>
+
+<div class="viewcode-block" id="VonMisesFisherLoss.log_cmk_exact">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">log_cmk_exact</span><span class="p">(</span>
+        <span class="bp">cls</span><span class="p">,</span> <span class="n">m</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kappa</span><span class="p">:</span> <span class="n">Tensor</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>  <span class="c1"># pylint: disable=invalid-name</span>
+<span class="w">        </span><span class="sd">"""Calculate $log C_{m}(k)$ term in von Mises-Fisher loss exactly."""</span>
+        <span class="k">return</span> <span class="n">LogCMK</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">kappa</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="VonMisesFisherLoss.log_cmk_approx">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">log_cmk_approx</span><span class="p">(</span>
+        <span class="bp">cls</span><span class="p">,</span> <span class="n">m</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kappa</span><span class="p">:</span> <span class="n">Tensor</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>  <span class="c1"># pylint: disable=invalid-name</span>
+<span class="w">        </span><span class="sd">"""Calculate $log C_{m}(k)$ term in von Mises-Fisher loss approx.</span>
+
+<span class="sd">        [https://arxiv.org/abs/1812.04616] Sec. 8.2 with additional minus sign.</span>
+<span class="sd">        """</span>
+        <span class="n">v</span> <span class="o">=</span> <span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="o">-</span> <span class="mf">0.5</span>
+        <span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">((</span><span class="n">v</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">kappa</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span>
+        <span class="n">b</span> <span class="o">=</span> <span class="n">v</span> <span class="o">-</span> <span class="mi">1</span>
+        <span class="k">return</span> <span class="o">-</span><span class="n">a</span> <span class="o">+</span> <span class="n">b</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">b</span> <span class="o">+</span> <span class="n">a</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="VonMisesFisherLoss.log_cmk">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">log_cmk</span><span class="p">(</span>
+        <span class="bp">cls</span><span class="p">,</span> <span class="n">m</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kappa</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">kappa_switch</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">100.0</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>  <span class="c1"># pylint: disable=invalid-name</span>
+<span class="w">        </span><span class="sd">"""Calculate $log C_{m}(k)$ term in von Mises-Fisher loss.</span>
+
+<span class="sd">        Since `log_cmk_exact` is diverges for `kappa` &gt;~ 700 (using float64</span>
+<span class="sd">        precision), and since `log_cmk_approx` is unaccurate for small `kappa`,</span>
+<span class="sd">        this method automatically switches between the two at `kappa_switch`,</span>
+<span class="sd">        ensuring continuity at this point.</span>
+<span class="sd">        """</span>
+        <span class="n">kappa_switch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">kappa_switch</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">kappa</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
+        <span class="n">mask_exact</span> <span class="o">=</span> <span class="n">kappa</span> <span class="o">&lt;</span> <span class="n">kappa_switch</span>
+
+        <span class="c1"># Ensure continuity at `kappa_switch`</span>
+        <span class="n">offset</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">log_cmk_approx</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">kappa_switch</span><span class="p">)</span> <span class="o">-</span> <span class="bp">cls</span><span class="o">.</span><span class="n">log_cmk_exact</span><span class="p">(</span>
+            <span class="n">m</span><span class="p">,</span> <span class="n">kappa_switch</span>
+        <span class="p">)</span>
+        <span class="n">ret</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">log_cmk_approx</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">kappa</span><span class="p">)</span> <span class="o">-</span> <span class="n">offset</span>
+        <span class="n">ret</span><span class="p">[</span><span class="n">mask_exact</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">log_cmk_exact</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">kappa</span><span class="p">[</span><span class="n">mask_exact</span><span class="p">])</span>
+        <span class="k">return</span> <span class="n">ret</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Calculate von Mises-Fisher loss for a vector in D dimensons.</span>
+
+<span class="sd">        This loss utilises the von Mises-Fisher distribution, which is a</span>
+<span class="sd">        probability distribution on the (D - 1) sphere in D-dimensional space.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            prediction: Predicted vector, of shape [batch_size, D].</span>
+<span class="sd">            target: Target unit vector, of shape [batch_size, D].</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            Elementwise von Mises-Fisher loss terms.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span>
+        <span class="k">assert</span> <span class="n">target</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span>
+        <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
+
+        <span class="c1"># Computing loss</span>
+        <span class="n">m</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
+        <span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">dotprod</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">prediction</span> <span class="o">*</span> <span class="n">target</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">elements</span> <span class="o">=</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">log_cmk</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">-</span> <span class="n">dotprod</span>
+        <span class="k">return</span> <span class="n">elements</span>
+
+    <span class="nd">@abstractmethod</span>
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+        <span class="k">raise</span> <span class="ne">NotImplementedError</span></div>
+
+
+
+<div class="viewcode-block" id="VonMisesFisher2DLoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher2DLoss">[docs]</a>
+<span class="k">class</span> <span class="nc">VonMisesFisher2DLoss</span><span class="p">(</span><span class="n">VonMisesFisherLoss</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""von Mises-Fisher loss function vectors in the 2D plane."""</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Calculate von Mises-Fisher loss for an angle in the 2D plane.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            prediction: Output of the model. Must have shape [N, 2] where 0th</span>
+<span class="sd">                column is a prediction of `angle` and 1st column is an estimate</span>
+<span class="sd">                of `kappa`.</span>
+<span class="sd">            target: Target tensor, extracted from graph object.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            loss: Elementwise von Mises-Fisher loss terms. Shape [N,]</span>
+<span class="sd">        """</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> <span class="ow">and</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">2</span>
+        <span class="k">assert</span> <span class="n">target</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span>
+        <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
+
+        <span class="c1"># Formatting target</span>
+        <span class="n">angle_true</span> <span class="o">=</span> <span class="n">target</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>
+        <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">angle_true</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">angle_true</span><span class="p">),</span>
+            <span class="p">],</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Formatting prediction</span>
+        <span class="n">angle_pred</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>
+        <span class="n">p</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span>
+            <span class="p">[</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">angle_pred</span><span class="p">),</span>
+                <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">angle_pred</span><span class="p">),</span>
+            <span class="p">],</span>
+            <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_evaluate</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="EuclideanDistanceLoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.EuclideanDistanceLoss">[docs]</a>
+<span class="k">class</span> <span class="nc">EuclideanDistanceLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Mean squared error in three dimensions."""</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Calculate 3D Euclidean distance between predicted and target.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            prediction: Output of the model. Must have shape [N, 3]</span>
+<span class="sd">            target: Target tensor, extracted from graph object.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            Elementwise von Mises-Fisher loss terms. Shape [N,]</span>
+<span class="sd">        """</span>
+        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span>
+            <span class="p">(</span><span class="n">prediction</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">target</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> <span class="o">**</span> <span class="mi">2</span>
+            <span class="o">+</span> <span class="p">(</span><span class="n">prediction</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">target</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span> <span class="o">**</span> <span class="mi">2</span>
+            <span class="o">+</span> <span class="p">(</span><span class="n">prediction</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">target</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">])</span> <span class="o">**</span> <span class="mi">2</span>
+        <span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="VonMisesFisher3DLoss">
+<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher3DLoss">[docs]</a>
+<span class="k">class</span> <span class="nc">VonMisesFisher3DLoss</span><span class="p">(</span><span class="n">VonMisesFisherLoss</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""von Mises-Fisher loss function vectors in the 3D plane."""</span>
+
+    <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Calculate von Mises-Fisher loss for a direction in the 3D.</span>
+
+<span class="sd">        Args:</span>
+<span class="sd">            prediction: Output of the model. Must have shape [N, 4] where</span>
+<span class="sd">                columns 0, 1, 2 are predictions of `direction` and last column</span>
+<span class="sd">                is an estimate of `kappa`.</span>
+<span class="sd">            target: Target tensor, extracted from graph object.</span>
+
+<span class="sd">        Returns:</span>
+<span class="sd">            Elementwise von Mises-Fisher loss terms. Shape [N,]</span>
+<span class="sd">        """</span>
+        <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> <span class="ow">and</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">4</span>
+        <span class="k">assert</span> <span class="n">target</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span>
+        <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
+
+        <span class="n">kappa</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
+        <span class="n">p</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">prediction</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_evaluate</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/training/utils.html b/_modules/graphnet/training/utils.html
new file mode 100644
index 000000000..eff87f698
--- /dev/null
+++ b/_modules/graphnet/training/utils.html
@@ -0,0 +1,656 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.training.utils &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/training/utils" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.training.utils </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-training-utils--page-root">Source code for graphnet.training.utils</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Utility functions for `graphnet.training`."""</span>
+
+<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
+<span class="kn">import</span> <span class="nn">os</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Callable</span>
+
+<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
+<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">Trainer</span>
+<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
+<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
+<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Batch</span><span class="p">,</span> <span class="n">Data</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.data.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span>
+<span class="kn">from</span> <span class="nn">graphnet.data.dataset</span> <span class="kn">import</span> <span class="n">SQLiteDataset</span>
+<span class="kn">from</span> <span class="nn">graphnet.data.dataset</span> <span class="kn">import</span> <span class="n">ParquetDataset</span>
+<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span>
+<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span>
+
+
+<div class="viewcode-block" id="collate_fn">
+<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.collate_fn">[docs]</a>
+<span class="k">def</span> <span class="nf">collate_fn</span><span class="p">(</span><span class="n">graphs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Data</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Batch</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Remove graphs with less than two DOM hits.</span>
+
+<span class="sd">    Should not occur in "production.</span>
+<span class="sd">    """</span>
+    <span class="n">graphs</span> <span class="o">=</span> <span class="p">[</span><span class="n">g</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">graphs</span> <span class="k">if</span> <span class="n">g</span><span class="o">.</span><span class="n">n_pulses</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">]</span>
+    <span class="k">return</span> <span class="n">Batch</span><span class="o">.</span><span class="n">from_data_list</span><span class="p">(</span><span class="n">graphs</span><span class="p">)</span></div>
+
+
+
+<span class="c1"># @TODO: Remove in favour of DataLoader{,.from_dataset_config}</span>
+<div class="viewcode-block" id="make_dataloader">
+<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.make_dataloader">[docs]</a>
+<span class="k">def</span> <span class="nf">make_dataloader</span><span class="p">(</span>
+    <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+    <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
+    <span class="n">graph_definition</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">GraphDefinition</span><span class="p">],</span>
+    <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+    <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+    <span class="o">*</span><span class="p">,</span>
+    <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+    <span class="n">shuffle</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
+    <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
+    <span class="n">persistent_workers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+    <span class="n">node_truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span><span class="p">,</span>
+    <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">string_selection</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span><span class="p">,</span>
+    <span class="n">labels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Construct `DataLoader` instance."""</span>
+    <span class="c1"># Check(s)</span>
+    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pulsemaps</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+        <span class="n">pulsemaps</span> <span class="o">=</span> <span class="p">[</span><span class="n">pulsemaps</span><span class="p">]</span>
+
+    <span class="n">dataset</span> <span class="o">=</span> <span class="n">SQLiteDataset</span><span class="p">(</span>
+        <span class="n">path</span><span class="o">=</span><span class="n">db</span><span class="p">,</span>
+        <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span>
+        <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span>
+        <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span>
+        <span class="n">selection</span><span class="o">=</span><span class="n">selection</span><span class="p">,</span>
+        <span class="n">node_truth</span><span class="o">=</span><span class="n">node_truth</span><span class="p">,</span>
+        <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span>
+        <span class="n">node_truth_table</span><span class="o">=</span><span class="n">node_truth_table</span><span class="p">,</span>
+        <span class="n">string_selection</span><span class="o">=</span><span class="n">string_selection</span><span class="p">,</span>
+        <span class="n">loss_weight_table</span><span class="o">=</span><span class="n">loss_weight_table</span><span class="p">,</span>
+        <span class="n">loss_weight_column</span><span class="o">=</span><span class="n">loss_weight_column</span><span class="p">,</span>
+        <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span>
+        <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span>
+    <span class="p">)</span>
+
+    <span class="c1"># adds custom labels to dataset</span>
+    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
+        <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">labels</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+            <span class="n">dataset</span><span class="o">.</span><span class="n">add_label</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">label</span><span class="p">,</span> <span class="n">fn</span><span class="o">=</span><span class="n">labels</span><span class="p">[</span><span class="n">label</span><span class="p">])</span>
+
+    <span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
+        <span class="n">dataset</span><span class="p">,</span>
+        <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
+        <span class="n">shuffle</span><span class="o">=</span><span class="n">shuffle</span><span class="p">,</span>
+        <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
+        <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">,</span>
+        <span class="n">persistent_workers</span><span class="o">=</span><span class="n">persistent_workers</span><span class="p">,</span>
+        <span class="n">prefetch_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
+    <span class="p">)</span>
+
+    <span class="k">return</span> <span class="n">dataloader</span></div>
+
+
+
+<span class="c1"># @TODO: Remove in favour of DataLoader{,.from_dataset_config}</span>
+<div class="viewcode-block" id="make_train_validation_dataloader">
+<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.make_train_validation_dataloader">[docs]</a>
+<span class="k">def</span> <span class="nf">make_train_validation_dataloader</span><span class="p">(</span>
+    <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+    <span class="n">graph_definition</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">GraphDefinition</span><span class="p">],</span>
+    <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
+    <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
+    <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+    <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+    <span class="o">*</span><span class="p">,</span>
+    <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+    <span class="n">database_indices</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">seed</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">42</span><span class="p">,</span>
+    <span class="n">test_size</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.33</span><span class="p">,</span>
+    <span class="n">num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
+    <span class="n">persistent_workers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+    <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span><span class="p">,</span>
+    <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">string_selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span><span class="p">,</span>
+    <span class="n">labels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">DataLoader</span><span class="p">,</span> <span class="n">DataLoader</span><span class="p">]:</span>
+<span class="w">    </span><span class="sd">"""Construct train and test `DataLoader` instances."""</span>
+    <span class="c1"># Reproducibility</span>
+    <span class="n">rng</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">default_rng</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">)</span>
+    <span class="c1"># Checks(s)</span>
+    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pulsemaps</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
+        <span class="n">pulsemaps</span> <span class="o">=</span> <span class="p">[</span><span class="n">pulsemaps</span><span class="p">]</span>
+
+    <span class="k">if</span> <span class="n">selection</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="c1"># If no selection is provided, use all events in dataset.</span>
+        <span class="n">dataset</span><span class="p">:</span> <span class="n">Dataset</span>
+        <span class="k">if</span> <span class="n">db</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".db"</span><span class="p">):</span>
+            <span class="n">dataset</span> <span class="o">=</span> <span class="n">SQLiteDataset</span><span class="p">(</span>
+                <span class="n">path</span><span class="o">=</span><span class="n">db</span><span class="p">,</span>
+                <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span>
+                <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span>
+                <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span>
+                <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span>
+                <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span>
+                <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span>
+            <span class="p">)</span>
+        <span class="k">elif</span> <span class="n">db</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".parquet"</span><span class="p">):</span>
+            <span class="n">dataset</span> <span class="o">=</span> <span class="n">ParquetDataset</span><span class="p">(</span>
+                <span class="n">path</span><span class="o">=</span><span class="n">db</span><span class="p">,</span>
+                <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span>
+                <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span>
+                <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span>
+                <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span>
+                <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span>
+                <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span>
+            <span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
+                <span class="sa">f</span><span class="s2">"File </span><span class="si">{</span><span class="n">db</span><span class="si">}</span><span class="s2"> with format </span><span class="si">{</span><span class="n">db</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'.'</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="si">}</span><span class="s2"> not supported."</span>
+            <span class="p">)</span>
+        <span class="n">selection</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">_get_all_indices</span><span class="p">()</span>
+
+    <span class="c1"># Perform train/validation split</span>
+    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">db</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+        <span class="n">df_for_shuffle</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span>
+            <span class="p">{</span><span class="s2">"event_no"</span><span class="p">:</span> <span class="n">selection</span><span class="p">,</span> <span class="s2">"db"</span><span class="p">:</span> <span class="n">database_indices</span><span class="p">}</span>
+        <span class="p">)</span>
+        <span class="n">shuffled_df</span> <span class="o">=</span> <span class="n">df_for_shuffle</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span>
+            <span class="n">frac</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">rng</span>
+        <span class="p">)</span>
+        <span class="n">training_df</span><span class="p">,</span> <span class="n">validation_df</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span>
+            <span class="n">shuffled_df</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="n">test_size</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">seed</span>
+        <span class="p">)</span>
+        <span class="n">training_selection</span> <span class="o">=</span> <span class="n">training_df</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+        <span class="n">validation_selection</span> <span class="o">=</span> <span class="n">validation_df</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
+    <span class="k">else</span><span class="p">:</span>
+        <span class="n">training_selection</span><span class="p">,</span> <span class="n">validation_selection</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span>
+            <span class="n">selection</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="n">test_size</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">seed</span>
+        <span class="p">)</span>
+
+    <span class="c1"># Create DataLoaders</span>
+    <span class="n">common_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
+        <span class="n">db</span><span class="o">=</span><span class="n">db</span><span class="p">,</span>
+        <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span>
+        <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span>
+        <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span>
+        <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
+        <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
+        <span class="n">persistent_workers</span><span class="o">=</span><span class="n">persistent_workers</span><span class="p">,</span>
+        <span class="n">node_truth</span><span class="o">=</span><span class="n">node_truth</span><span class="p">,</span>
+        <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span>
+        <span class="n">node_truth_table</span><span class="o">=</span><span class="n">node_truth_table</span><span class="p">,</span>
+        <span class="n">string_selection</span><span class="o">=</span><span class="n">string_selection</span><span class="p">,</span>
+        <span class="n">loss_weight_column</span><span class="o">=</span><span class="n">loss_weight_column</span><span class="p">,</span>
+        <span class="n">loss_weight_table</span><span class="o">=</span><span class="n">loss_weight_table</span><span class="p">,</span>
+        <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span>
+        <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">,</span>
+        <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span>
+    <span class="p">)</span>
+
+    <span class="n">training_dataloader</span> <span class="o">=</span> <span class="n">make_dataloader</span><span class="p">(</span>
+        <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+        <span class="n">selection</span><span class="o">=</span><span class="n">training_selection</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">common_kwargs</span><span class="p">,</span>  <span class="c1"># type: ignore[arg-type]</span>
+    <span class="p">)</span>
+
+    <span class="n">validation_dataloader</span> <span class="o">=</span> <span class="n">make_dataloader</span><span class="p">(</span>
+        <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
+        <span class="n">selection</span><span class="o">=</span><span class="n">validation_selection</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">common_kwargs</span><span class="p">,</span>  <span class="c1"># type: ignore[arg-type]</span>
+    <span class="p">)</span>
+
+    <span class="k">return</span> <span class="p">(</span>
+        <span class="n">training_dataloader</span><span class="p">,</span>
+        <span class="n">validation_dataloader</span><span class="p">,</span>
+    <span class="p">)</span></div>
+
+
+
+<span class="c1"># @TODO: Remove in favour of Model.predict{,_as_dataframe}</span>
+<div class="viewcode-block" id="get_predictions">
+<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.get_predictions">[docs]</a>
+<span class="k">def</span> <span class="nf">get_predictions</span><span class="p">(</span>
+    <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span>
+    <span class="n">model</span><span class="p">:</span> <span class="n">Model</span><span class="p">,</span>
+    <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span>
+    <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+    <span class="o">*</span><span class="p">,</span>
+    <span class="n">node_level</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+    <span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Get `model` predictions on `dataloader`."""</span>
+    <span class="c1"># Gets predictions from model on the events in the dataloader.</span>
+    <span class="c1"># NOTE: dataloader must NOT have shuffle = True!</span>
+
+    <span class="c1"># Check(s)</span>
+    <span class="k">if</span> <span class="n">additional_attributes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">additional_attributes</span> <span class="o">=</span> <span class="p">[]</span>
+    <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">additional_attributes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+
+    <span class="c1"># Set model to inference mode</span>
+    <span class="n">model</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span>
+
+    <span class="c1"># Get predictions</span>
+    <span class="n">predictions_torch</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">)</span>
+    <span class="n">predictions_list</span> <span class="o">=</span> <span class="p">[</span>
+        <span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">predictions_torch</span>
+    <span class="p">]</span>  <span class="c1"># Assuming single task</span>
+    <span class="n">predictions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+    <span class="k">try</span><span class="p">:</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">)</span> <span class="o">==</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
+    <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span>
+        <span class="n">predictions</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">)</span> <span class="o">==</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
+
+    <span class="c1"># Get additional attributes</span>
+    <span class="n">attributes</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]]</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">(</span>
+        <span class="p">[(</span><span class="n">attr</span><span class="p">,</span> <span class="p">[])</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">additional_attributes</span><span class="p">]</span>
+    <span class="p">)</span>
+    <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
+        <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">attributes</span><span class="p">:</span>
+            <span class="n">attribute</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">attr</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
+            <span class="k">if</span> <span class="n">node_level</span><span class="p">:</span>
+                <span class="k">if</span> <span class="n">attr</span> <span class="o">==</span> <span class="s2">"event_no"</span><span class="p">:</span>
+                    <span class="n">attribute</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
+                        <span class="n">attribute</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="s2">"n_pulses"</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
+                    <span class="p">)</span>
+            <span class="n">attributes</span><span class="p">[</span><span class="n">attr</span><span class="p">]</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">attribute</span><span class="p">)</span>
+
+    <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
+        <span class="p">[</span><span class="n">predictions</span><span class="p">]</span>
+        <span class="o">+</span> <span class="p">[</span>
+            <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">values</span><span class="p">)[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="k">for</span> <span class="n">values</span> <span class="ow">in</span> <span class="n">attributes</span><span class="o">.</span><span class="n">values</span><span class="p">()</span>
+        <span class="p">],</span>
+        <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
+    <span class="p">)</span>
+
+    <span class="n">results</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span>
+        <span class="n">data</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">prediction_columns</span> <span class="o">+</span> <span class="n">additional_attributes</span>
+    <span class="p">)</span>
+    <span class="k">return</span> <span class="n">results</span></div>
+
+
+
+<span class="c1"># @TODO: Remove</span>
+<div class="viewcode-block" id="save_results">
+<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.save_results">[docs]</a>
+<span class="k">def</span> <span class="nf">save_results</span><span class="p">(</span>
+    <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">tag</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">results</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span> <span class="n">archive</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">Model</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Save trained model and prediction `results` in `db`."""</span>
+    <span class="n">db_name</span> <span class="o">=</span> <span class="n">db</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"."</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
+    <span class="n">path</span> <span class="o">=</span> <span class="n">archive</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">db_name</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">tag</span>
+    <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+    <span class="n">results</span><span class="o">.</span><span class="n">to_csv</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">"/results.csv"</span><span class="p">)</span>
+    <span class="n">model</span><span class="o">.</span><span class="n">save_state_dict</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">tag</span> <span class="o">+</span> <span class="s2">"_state_dict.pth"</span><span class="p">)</span>
+    <span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">tag</span> <span class="o">+</span> <span class="s2">"_model.pth"</span><span class="p">)</span>
+    <span class="n">Logger</span><span class="p">()</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">"Results saved at: </span><span class="se">\n</span><span class="s2"> </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">path</span><span class="p">)</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/training/weight_fitting.html b/_modules/graphnet/training/weight_fitting.html
index 9811e69cf..909eb3dbe 100644
--- a/_modules/graphnet/training/weight_fitting.html
+++ b/_modules/graphnet/training/weight_fitting.html
@@ -556,7 +556,7 @@ <h1 id="modules-graphnet-training-weight-fitting--page-root">Source code for gra
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/utilities/argparse.html b/_modules/graphnet/utilities/argparse.html
index 43877f0e5..a68f077b2 100644
--- a/_modules/graphnet/utilities/argparse.html
+++ b/_modules/graphnet/utilities/argparse.html
@@ -515,7 +515,7 @@ <h1 id="modules-graphnet-utilities-argparse--page-root">Source code for graphnet
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/utilities/config/base_config.html b/_modules/graphnet/utilities/config/base_config.html
new file mode 100644
index 000000000..f57c1f185
--- /dev/null
+++ b/_modules/graphnet/utilities/config/base_config.html
@@ -0,0 +1,449 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.utilities.config.base_config &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/utilities/config/base_config" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.utilities.config.base_config </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-utilities-config-base-config--page-root">Source code for graphnet.utilities.config.base_config</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Base config class(es)."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span>
+<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
+<span class="kn">import</span> <span class="nn">inspect</span>
+<span class="kn">import</span> <span class="nn">sys</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Optional</span>
+
+<span class="kn">from</span> <span class="nn">pydantic</span> <span class="kn">import</span> <span class="n">BaseModel</span>
+<span class="kn">import</span> <span class="nn">ruamel.yaml</span> <span class="k">as</span> <span class="nn">yaml</span>
+
+
+<span class="n">CONFIG_FILES_SUFFIXES</span> <span class="o">=</span> <span class="p">(</span><span class="s2">".yml"</span><span class="p">,</span> <span class="s2">".yaml"</span><span class="p">)</span>
+
+
+<div class="viewcode-block" id="BaseConfig">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig">[docs]</a>
+<span class="k">class</span> <span class="nc">BaseConfig</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base class for Configs."""</span>
+
+<div class="viewcode-block" id="BaseConfig.load">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.load">[docs]</a>
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"BaseConfig"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Load BaseConfig from `path`."""</span>
+        <span class="k">assert</span> <span class="n">path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span>
+            <span class="n">CONFIG_FILES_SUFFIXES</span>
+        <span class="p">),</span> <span class="s2">"Please specify YAML config file."</span>
+        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
+            <span class="n">yaml_</span> <span class="o">=</span> <span class="n">yaml</span><span class="o">.</span><span class="n">YAML</span><span class="p">(</span><span class="n">typ</span><span class="o">=</span><span class="s2">"safe"</span><span class="p">,</span> <span class="n">pure</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+            <span class="n">config_dict</span> <span class="o">=</span> <span class="n">yaml_</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">config_dict</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="BaseConfig.dump">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.dump">[docs]</a>
+    <span class="k">def</span> <span class="nf">dump</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Save BaseConfig to `path` as YAML file, or return as string."""</span>
+        <span class="n">config_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">as_dict</span><span class="p">()[</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">]</span>
+        <span class="n">yaml_</span> <span class="o">=</span> <span class="n">yaml</span><span class="o">.</span><span class="n">YAML</span><span class="p">(</span><span class="n">typ</span><span class="o">=</span><span class="s2">"safe"</span><span class="p">,</span> <span class="n">pure</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">path</span><span class="p">:</span>
+            <span class="k">if</span> <span class="ow">not</span> <span class="n">path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="n">CONFIG_FILES_SUFFIXES</span><span class="p">):</span>
+                <span class="n">path</span> <span class="o">+=</span> <span class="n">CONFIG_FILES_SUFFIXES</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
+                <span class="n">yaml_</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">config_dict</span><span class="p">,</span> <span class="n">f</span><span class="p">)</span>
+            <span class="k">return</span> <span class="kc">None</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">yaml_</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">config_dict</span><span class="p">,</span> <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="BaseConfig.as_dict">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.as_dict">[docs]</a>
+    <span class="k">def</span> <span class="nf">as_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Represent BaseConfig as a dict.</span>
+
+<span class="sd">        This builds on `BaseModel.dict()` but can be overwritten.</span>
+<span class="sd">        """</span>
+        <span class="k">return</span> <span class="p">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">dict</span><span class="p">()}</span></div>
+</div>
+
+
+
+<div class="viewcode-block" id="get_all_argument_values">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.get_all_argument_values">[docs]</a>
+<span class="k">def</span> <span class="nf">get_all_argument_values</span><span class="p">(</span>
+    <span class="n">fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span>
+<span class="w">    </span><span class="sd">"""Return dict of all argument values to `fn`, including defaults."""</span>
+    <span class="c1"># Get all default argument values</span>
+    <span class="n">cfg</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">()</span>
+    <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">inspect</span><span class="o">.</span><span class="n">signature</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span><span class="o">.</span><span class="n">parameters</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
+        <span class="c1"># Don't save `self`, `*args`, or `**kwargs`</span>
+        <span class="k">if</span> <span class="n">key</span> <span class="o">==</span> <span class="s2">"self"</span> <span class="ow">or</span> <span class="n">param</span><span class="o">.</span><span class="n">kind</span> <span class="ow">in</span> <span class="p">[</span>
+            <span class="n">param</span><span class="o">.</span><span class="n">VAR_POSITIONAL</span><span class="p">,</span>
+            <span class="n">param</span><span class="o">.</span><span class="n">VAR_KEYWORD</span><span class="p">,</span>
+        <span class="p">]:</span>
+            <span class="k">continue</span>
+        <span class="n">cfg</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">default</span>
+
+    <span class="c1"># Add positional arguments</span>
+    <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">cfg</span><span class="o">.</span><span class="n">keys</span><span class="p">(),</span> <span class="n">args</span><span class="p">):</span>
+        <span class="n">cfg</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span>
+
+    <span class="c1"># Add keyword arguments</span>
+    <span class="n">cfg</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">kwargs</span><span class="p">)</span>
+
+    <span class="k">return</span> <span class="n">cfg</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/utilities/config/configurable.html b/_modules/graphnet/utilities/config/configurable.html
new file mode 100644
index 000000000..cfe023b65
--- /dev/null
+++ b/_modules/graphnet/utilities/config/configurable.html
@@ -0,0 +1,408 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.utilities.config.configurable &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/utilities/config/configurable" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.utilities.config.configurable </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-utilities-config-configurable--page-root">Source code for graphnet.utilities.config.configurable</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Bases for all configurable classes in  `graphnet`."""</span>
+
+<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractclassmethod</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Union</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config.base_config</span> <span class="kn">import</span> <span class="n">BaseConfig</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span>
+
+
+<div class="viewcode-block" id="Configurable">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable">[docs]</a>
+<span class="k">class</span> <span class="nc">Configurable</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Base class for all configurable classes in graphnet."""</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `Configurable`."""</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_config</span><span class="p">:</span> <span class="n">BaseConfig</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
+
+    <span class="nd">@final</span>
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">BaseConfig</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return configuration to re-create the instance."""</span>
+        <span class="k">try</span><span class="p">:</span>
+            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_config</span>
+        <span class="k">except</span> <span class="ne">AttributeError</span><span class="p">:</span>
+            <span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span>
+                <span class="s2">"Config was not set. "</span>
+                <span class="s2">"Did you wrap the class constructor with `save_config`?"</span>
+            <span class="p">)</span>
+
+<div class="viewcode-block" id="Configurable.save_config">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.save_config">[docs]</a>
+    <span class="nd">@final</span>
+    <span class="k">def</span> <span class="nf">save_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Save Config to `path` as YAML file."""</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">path</span><span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="Configurable.from_config">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.from_config">[docs]</a>
+    <span class="nd">@abstractclassmethod</span>
+    <span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">source</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">BaseConfig</span><span class="p">,</span> <span class="nb">str</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct instance from `source` configuration."""</span></div>
+</div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/utilities/config/dataset_config.html b/_modules/graphnet/utilities/config/dataset_config.html
new file mode 100644
index 000000000..91b338109
--- /dev/null
+++ b/_modules/graphnet/utilities/config/dataset_config.html
@@ -0,0 +1,585 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.utilities.config.dataset_config &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/utilities/config/dataset_config" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.utilities.config.dataset_config </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-utilities-config-dataset-config--page-root">Source code for graphnet.utilities.config.dataset_config</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Config classes for the `graphnet.data.dataset` module."""</span>
+
+<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">TYPE_CHECKING</span><span class="p">,</span>
+    <span class="n">Any</span><span class="p">,</span>
+    <span class="n">Callable</span><span class="p">,</span>
+    <span class="n">Dict</span><span class="p">,</span>
+    <span class="n">List</span><span class="p">,</span>
+    <span class="n">Optional</span><span class="p">,</span>
+    <span class="n">Union</span><span class="p">,</span>
+<span class="p">)</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config.base_config</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">BaseConfig</span><span class="p">,</span>
+    <span class="n">get_all_argument_values</span><span class="p">,</span>
+<span class="p">)</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config.parsing</span> <span class="kn">import</span> <span class="n">traverse_and_apply</span>
+<span class="kn">from</span> <span class="nn">.model_config</span> <span class="kn">import</span> <span class="n">ModelConfig</span>
+
+<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
+    <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+
+
+<span class="n">BACKEND_LOOKUP</span> <span class="o">=</span> <span class="p">{</span>
+    <span class="s2">"db"</span><span class="p">:</span> <span class="s2">"sqlite"</span><span class="p">,</span>
+    <span class="s2">"parquet"</span><span class="p">:</span> <span class="s2">"parquet"</span><span class="p">,</span>
+<span class="p">}</span>
+
+
+<div class="viewcode-block" id="DatasetConfig">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig">[docs]</a>
+<span class="k">class</span> <span class="nc">DatasetConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Configuration for all `Dataset`s."""</span>
+
+    <span class="c1"># Fields</span>
+    <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span>
+    <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span>
+    <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
+    <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
+    <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span>
+    <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span>
+    <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="n">string_selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span>
+        <span class="n">Union</span><span class="p">[</span>
+            <span class="nb">str</span><span class="p">,</span>
+            <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span>
+            <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]],</span>
+            <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]],</span>
+        <span class="p">]</span>
+    <span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
+    <span class="n">graph_definition</span><span class="p">:</span> <span class="n">Any</span> <span class="o">=</span> <span class="kc">None</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">data</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `DataConfig`.</span>
+
+<span class="sd">        Can be used for dataset configuration as code, thereby making dataset</span>
+<span class="sd">        construction more transparent and reproducible.</span>
+
+<span class="sd">        Examples:</span>
+<span class="sd">            In one session, do:</span>
+
+<span class="sd">            &gt;&gt;&gt; dataset = Dataset(...)</span>
+<span class="sd">            &gt;&gt;&gt; dataset.config.dump()</span>
+<span class="sd">            path: (...)</span>
+<span class="sd">            pulsemaps:</span>
+<span class="sd">                - (...)</span>
+<span class="sd">            (...)</span>
+<span class="sd">            &gt;&gt;&gt; dataset.config.dump("dataset.yml")</span>
+
+<span class="sd">            In another session, you can then do:</span>
+<span class="sd">            &gt;&gt;&gt; dataset = Dataset.from_config("dataset.yml")</span>
+
+<span class="sd">            # Uniquely for `DatasetConfig`, you can also define and load</span>
+<span class="sd">            # multiple datasets</span>
+<span class="sd">            &gt;&gt;&gt; dataset.config.selection = {</span>
+<span class="sd">                "train": "event_no % 2 == 0",</span>
+<span class="sd">                "test": "event_no % 2 == 1",</span>
+<span class="sd">            }</span>
+<span class="sd">            &gt;&gt;&gt; dataset.config.dump("dataset.yml")</span>
+<span class="sd">            &gt;&gt;&gt; datasets: Dict[str, Dataset] = Dataset.from_config(</span>
+<span class="sd">                "dataset.yml"</span>
+<span class="sd">            )</span>
+<span class="sd">            &gt;&gt;&gt; datasets</span>
+<span class="sd">            {</span>
+<span class="sd">                "train": Dataset(...),</span>
+<span class="sd">                "test": Dataset(...),</span>
+<span class="sd">            }</span>
+
+<span class="sd">            # You can also combine multiple selections into a single, named</span>
+<span class="sd">            # dataset</span>
+<span class="sd">            &gt;&gt;&gt; dataset.config.selection = {</span>
+<span class="sd">                "train": [</span>
+<span class="sd">                    "event_no % 2 == 0 &amp; abs(pid) == 12",</span>
+<span class="sd">                    "event_no % 2 == 0 &amp; abs(pid) == 14",</span>
+<span class="sd">                    "event_no % 2 == 0 &amp; abs(pid) == 16",</span>
+<span class="sd">                ],</span>
+<span class="sd">                (...)</span>
+<span class="sd">            }</span>
+<span class="sd">            &gt;&gt;&gt; dataset.config.dump("dataset.yml")</span>
+<span class="sd">            &gt;&gt;&gt; datasets: Dict[str, EnsembleDataset] = Dataset.from_config(</span>
+<span class="sd">                "dataset.yml"</span>
+<span class="sd">            )</span>
+<span class="sd">            &gt;&gt;&gt; datasets</span>
+<span class="sd">            {</span>
+<span class="sd">                "train": EnsembleDataset(...),</span>
+<span class="sd">                (...)</span>
+<span class="sd">            }</span>
+
+<span class="sd">            # Finally, you can still reference existing selection files in CSV</span>
+<span class="sd">            # or JSON formats:</span>
+<span class="sd">            &gt;&gt;&gt; dataset.config.selection = {</span>
+<span class="sd">                "train": "50000 random events ~ train_selection.csv",</span>
+<span class="sd">                "test": "test_selection.csv",</span>
+<span class="sd">            }</span>
+<span class="sd">        """</span>
+        <span class="c1"># Single-key dictioaries are unpacked</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s2">"selection"</span><span class="p">],</span> <span class="nb">dict</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s2">"selection"</span><span class="p">])</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
+            <span class="n">data</span><span class="p">[</span><span class="s2">"selection"</span><span class="p">]</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s2">"selection"</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">()))</span>
+
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
+        <span class="n">path</span><span class="p">:</span> <span class="nb">str</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
+            <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
+            <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">path</span>
+        <span class="n">suffix</span> <span class="o">=</span> <span class="n">path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"."</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
+        <span class="k">try</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">BACKEND_LOOKUP</span><span class="p">[</span><span class="n">suffix</span><span class="p">]</span>
+        <span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
+                <span class="sa">f</span><span class="s2">"Dataset at `path` </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="si">}</span><span class="s2"> with suffix </span><span class="si">{</span><span class="n">suffix</span><span class="si">}</span><span class="s2"> not "</span>
+                <span class="s2">"supported."</span>
+            <span class="p">)</span>
+            <span class="k">raise</span>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">_dataset_class</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">type</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Return the `Dataset` class implementation for this configuration."""</span>
+        <span class="kn">from</span> <span class="nn">graphnet.data.dataset.sqlite</span> <span class="kn">import</span> <span class="n">SQLiteDataset</span>
+        <span class="kn">from</span> <span class="nn">graphnet.data.dataset.parquet</span> <span class="kn">import</span> <span class="n">ParquetDataset</span>
+
+        <span class="n">dataset_class</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">"sqlite"</span><span class="p">:</span> <span class="n">SQLiteDataset</span><span class="p">,</span>
+            <span class="s2">"parquet"</span><span class="p">:</span> <span class="n">ParquetDataset</span><span class="p">,</span>
+        <span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">_backend</span><span class="p">]</span>
+
+        <span class="k">return</span> <span class="n">dataset_class</span>
+
+<div class="viewcode-block" id="DatasetConfig.as_dict">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict">[docs]</a>
+    <span class="k">def</span> <span class="nf">as_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Represent ModelConfig as a dict.</span>
+
+<span class="sd">        This builds on `BaseModel.dict()` but wraps the output in a single-key</span>
+<span class="sd">        dictionary to make it unambiguous to identify model arguments that are</span>
+<span class="sd">        themselves models.</span>
+<span class="sd">        """</span>
+        <span class="n">config_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dict</span><span class="p">()</span>
+        <span class="n">config_dict</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span>
+            <span class="n">obj</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="o">**</span><span class="n">config_dict</span><span class="p">),</span> <span class="n">fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_parse_torch</span>
+        <span class="p">)</span>
+        <span class="k">return</span> <span class="p">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">:</span> <span class="n">config_dict</span><span class="p">}</span></div>
+
+
+    <span class="k">def</span> <span class="nf">_parse_torch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">obj</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
+        <span class="kn">import</span> <span class="nn">torch</span>
+
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span>
+            <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="fm">__str__</span><span class="p">()</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">obj</span></div>
+
+
+
+<div class="viewcode-block" id="save_dataset_config">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.save_dataset_config">[docs]</a>
+<span class="k">def</span> <span class="nf">save_dataset_config</span><span class="p">(</span><span class="n">init_fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Callable</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Save the arguments to `__init__` functions as member `DatasetConfig`."""</span>
+
+    <span class="k">def</span> <span class="nf">_replace_model_instance_with_config</span><span class="p">(</span>
+        <span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s2">"Model"</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Replace `Model` instances in `obj` with their `ModelConfig`."""</span>
+        <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+        <span class="kn">import</span> <span class="nn">torch</span>
+
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">Model</span><span class="p">):</span>
+            <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">config</span>
+
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span>
+            <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="fm">__str__</span><span class="p">()</span>
+
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">obj</span>
+
+    <span class="nd">@wraps</span><span class="p">(</span><span class="n">init_fn</span><span class="p">)</span>
+    <span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Set `DatasetConfig` after calling `init_fn`."""</span>
+        <span class="c1"># Call wrapped method</span>
+        <span class="n">ret</span> <span class="o">=</span> <span class="n">init_fn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+        <span class="c1"># Get all argument values, including defaults</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">get_all_argument_values</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+        <span class="c1"># Handle nested `Model`s, etc.</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">_replace_model_instance_with_config</span><span class="p">)</span>
+        <span class="c1"># Add `DatasetConfig` as member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_config</span> <span class="o">=</span> <span class="n">DatasetConfig</span><span class="p">(</span><span class="o">**</span><span class="n">cfg</span><span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">ret</span>
+
+    <span class="k">return</span> <span class="n">wrapper</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/utilities/config/model_config.html b/_modules/graphnet/utilities/config/model_config.html
new file mode 100644
index 000000000..e63e73186
--- /dev/null
+++ b/_modules/graphnet/utilities/config/model_config.html
@@ -0,0 +1,654 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.utilities.config.model_config &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/utilities/config/model_config" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.utilities.config.model_config </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-utilities-config-model-config--page-root">Source code for graphnet.utilities.config.model_config</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Config classes for the `graphnet.models` module."""</span>
+<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span>
+<span class="kn">import</span> <span class="nn">inspect</span>
+<span class="kn">import</span> <span class="nn">re</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">TYPE_CHECKING</span><span class="p">,</span>
+    <span class="n">Any</span><span class="p">,</span>
+    <span class="n">Callable</span><span class="p">,</span>
+    <span class="n">Dict</span><span class="p">,</span>
+    <span class="n">List</span><span class="p">,</span>
+    <span class="n">Optional</span><span class="p">,</span>
+    <span class="n">Union</span><span class="p">,</span>
+<span class="p">)</span>
+<span class="kn">import</span> <span class="nn">torch</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config.base_config</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">BaseConfig</span><span class="p">,</span>
+    <span class="n">get_all_argument_values</span><span class="p">,</span>
+<span class="p">)</span>
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config.parsing</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">traverse_and_apply</span><span class="p">,</span>
+    <span class="n">get_all_grapnet_classes</span><span class="p">,</span>
+<span class="p">)</span>
+
+<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
+    <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+
+
+<span class="n">FUNCTION_DEFINITION_PATTERN</span> <span class="o">=</span> <span class="p">(</span>
+    <span class="sa">r</span><span class="s2">"^def (?P&lt;function_name&gt;[a-zA-Z]</span><span class="si">{1}</span><span class="s2">[a-zA-Z0-9_]+) *\(.*\) *:"</span>
+<span class="p">)</span>
+
+
+<div class="viewcode-block" id="ModelConfig">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig">[docs]</a>
+<span class="k">class</span> <span class="nc">ModelConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Configuration for all `Model`s."""</span>
+
+    <span class="c1"># Fields</span>
+    <span class="n">class_name</span><span class="p">:</span> <span class="nb">str</span>
+    <span class="n">arguments</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">data</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `ModelConfig`.</span>
+
+<span class="sd">        Can be used for model configuration as code, thereby making model</span>
+<span class="sd">        construction more transparent and reproducible. Note that this does</span>
+<span class="sd">        *not* save any trainable weights, meaning this is only a configuration</span>
+<span class="sd">        for the model's hyperparameters. Any model instantiated from a</span>
+<span class="sd">        ModelConfig or file will be randomly initialised, and thus should be</span>
+<span class="sd">        trained.</span>
+
+<span class="sd">        Examples:</span>
+<span class="sd">            In one session, do:</span>
+
+<span class="sd">            &gt;&gt;&gt; model = Model(...)</span>
+<span class="sd">            &gt;&gt;&gt; model.config.dump()</span>
+<span class="sd">            arguments:</span>
+<span class="sd">                - (...): (...)</span>
+<span class="sd">            class_name: Model</span>
+<span class="sd">            &gt;&gt;&gt; model.config.dump("model.yml")</span>
+
+<span class="sd">            In another session, you can then do:</span>
+<span class="sd">            &gt;&gt;&gt; model = Model.from_config("model.yml")</span>
+<span class="sd">        """</span>
+        <span class="c1"># Parse any nested `ModelConfig` arguments</span>
+        <span class="k">for</span> <span class="n">arg</span> <span class="ow">in</span> <span class="n">data</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">]:</span>
+            <span class="n">value</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">][</span><span class="n">arg</span><span class="p">]</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
+                <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">elem</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">value</span><span class="p">):</span>
+                    <span class="n">data</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">][</span><span class="n">arg</span><span class="p">][</span>
+                        <span class="n">ix</span>
+                    <span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parse_if_model_config_entry</span><span class="p">(</span><span class="n">elem</span><span class="p">)</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="n">data</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">][</span><span class="n">arg</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parse_if_model_config_entry</span><span class="p">(</span>
+                    <span class="n">value</span>
+                <span class="p">)</span>
+        <span class="c1"># Base class constructor</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_is_model_config_entry</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">entry</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Check whether dictionary entry is a `ModelConfig`."""</span>
+        <span class="k">return</span> <span class="p">(</span>
+            <span class="nb">isinstance</span><span class="p">(</span><span class="n">entry</span><span class="p">,</span> <span class="nb">dict</span><span class="p">)</span>
+            <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">entry</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
+            <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="ow">in</span> <span class="n">entry</span>
+        <span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_parse_if_model_config_entry</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">entry</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="s2">"ModelConfig"</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Parse dictionary entry to `ModelConfig`."""</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_is_model_config_entry</span><span class="p">(</span><span class="n">entry</span><span class="p">):</span>
+            <span class="n">config_dict</span> <span class="o">=</span> <span class="n">entry</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">]</span>
+            <span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="p">(</span><span class="o">**</span><span class="n">config_dict</span><span class="p">)</span>
+            <span class="k">return</span> <span class="n">config</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">entry</span>
+
+    <span class="k">def</span> <span class="nf">_construct_model</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">trust</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">load_modules</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">"Model"</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Construct `Model` instance from `self` configuration.</span>
+
+<span class="sd">        Used as the basis for `Model.from_config`.</span>
+<span class="sd">        """</span>
+        <span class="c1"># Check(s)</span>
+        <span class="k">if</span> <span class="n">load_modules</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">load_modules</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"torch"</span><span class="p">]</span>
+        <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">load_modules</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
+
+        <span class="c1"># Load any additional modules into the global namespace</span>
+        <span class="k">for</span> <span class="n">module</span> <span class="ow">in</span> <span class="n">load_modules</span><span class="p">:</span>
+            <span class="k">assert</span> <span class="n">re</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="s2">"^[a-zA-Z_]+$"</span><span class="p">,</span> <span class="n">module</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+            <span class="k">if</span> <span class="n">module</span> <span class="ow">in</span> <span class="nb">globals</span><span class="p">():</span>
+                <span class="k">continue</span>
+            <span class="n">exec</span><span class="p">(</span><span class="sa">f</span><span class="s2">"import </span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="nb">globals</span><span class="p">())</span>
+
+        <span class="c1"># Get a lookup for all classes in `graphnet`</span>
+        <span class="kn">import</span> <span class="nn">graphnet.data</span>
+        <span class="kn">import</span> <span class="nn">graphnet.models</span>
+        <span class="kn">import</span> <span class="nn">graphnet.training</span>
+
+        <span class="n">namespace_classes</span> <span class="o">=</span> <span class="n">get_all_grapnet_classes</span><span class="p">(</span>
+            <span class="n">graphnet</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">graphnet</span><span class="o">.</span><span class="n">models</span><span class="p">,</span> <span class="n">graphnet</span><span class="o">.</span><span class="n">training</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Parse potential ModelConfig arguments</span>
+        <span class="n">arguments</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">arguments</span><span class="p">)</span>
+        <span class="n">arguments</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span>
+            <span class="n">arguments</span><span class="p">,</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_deserialise</span><span class="p">,</span>
+            <span class="n">fn_kwargs</span><span class="o">=</span><span class="p">{</span><span class="s2">"trust"</span><span class="p">:</span> <span class="n">trust</span><span class="p">},</span>
+        <span class="p">)</span>
+
+        <span class="c1"># Construct model based on arguments</span>
+        <span class="k">return</span> <span class="n">namespace_classes</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">class_name</span><span class="p">](</span><span class="o">**</span><span class="n">arguments</span><span class="p">)</span>
+
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">_deserialise</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">obj</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">trust</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">):</span>
+            <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+
+            <span class="k">return</span> <span class="n">Model</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">trust</span><span class="o">=</span><span class="n">trust</span><span class="p">)</span>
+
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"!lambda"</span><span class="p">):</span>
+            <span class="k">if</span> <span class="n">trust</span><span class="p">:</span>
+                <span class="n">source</span> <span class="o">=</span> <span class="n">obj</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
+                <span class="n">f</span> <span class="o">=</span> <span class="nb">eval</span><span class="p">(</span><span class="n">source</span><span class="p">)</span>
+
+                <span class="c1"># Save a copy of the source code attached to the callable,</span>
+                <span class="c1"># since the `inspect` module is not able to get the source code</span>
+                <span class="c1"># for functions that are not defined on file.</span>
+                <span class="c1"># See `self._serialise`.</span>
+                <span class="n">f</span><span class="o">.</span><span class="n">_source</span> <span class="o">=</span> <span class="n">source</span>
+                <span class="k">return</span> <span class="n">f</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                    <span class="s2">"Constructing model containing a lambda function "</span>
+                    <span class="sa">f</span><span class="s2">"(</span><span class="si">{</span><span class="n">obj</span><span class="si">}</span><span class="s2">) with `trust=False`. If you trust the lambda "</span>
+                    <span class="s2">"functions in this ModelConfig, set `trust=True` and "</span>
+                    <span class="s2">"reconstruct the model again."</span>
+                <span class="p">)</span>
+
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"!function"</span><span class="p">):</span>
+            <span class="k">if</span> <span class="n">trust</span><span class="p">:</span>
+                <span class="n">source</span> <span class="o">=</span> <span class="n">obj</span><span class="p">[</span><span class="mi">10</span><span class="p">:]</span>
+                <span class="n">match</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="n">FUNCTION_DEFINITION_PATTERN</span><span class="p">,</span> <span class="n">source</span><span class="p">)</span>
+                <span class="k">assert</span> <span class="n">match</span>
+                <span class="n">exec</span><span class="p">(</span><span class="n">source</span><span class="p">)</span>
+                <span class="n">fn</span> <span class="o">=</span> <span class="nb">eval</span><span class="p">(</span><span class="n">match</span><span class="o">.</span><span class="n">group</span><span class="p">(</span><span class="s2">"function_name"</span><span class="p">))</span>
+                <span class="k">return</span> <span class="n">fn</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                    <span class="sa">f</span><span class="s2">"Constructing model containing a function (</span><span class="si">{</span><span class="n">obj</span><span class="si">}</span><span class="s2">) with "</span>
+                    <span class="s2">"`trust=False`. If you trust the functions in this "</span>
+                    <span class="s2">"ModelConfig, set `trust=True` and reconstruct the model "</span>
+                    <span class="s2">"again."</span>
+                <span class="p">)</span>
+
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"!class"</span><span class="p">):</span>
+            <span class="k">if</span> <span class="n">trust</span><span class="p">:</span>
+                <span class="n">module</span><span class="p">,</span> <span class="n">class_name</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">split</span><span class="p">()[</span><span class="mi">1</span><span class="p">:]</span>
+                <span class="n">exec</span><span class="p">(</span><span class="sa">f</span><span class="s2">"from </span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2"> import </span><span class="si">{</span><span class="n">class_name</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
+                <span class="k">return</span> <span class="nb">eval</span><span class="p">(</span><span class="n">class_name</span><span class="p">)</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                    <span class="sa">f</span><span class="s2">"Constructing model containing a class (</span><span class="si">{</span><span class="n">obj</span><span class="si">}</span><span class="s2">) with "</span>
+                    <span class="s2">"`trust=False`. If you trust the class definitions in "</span>
+                    <span class="s2">"this ModelConfig, set `trust=True` and reconstruct the "</span>
+                    <span class="s2">"model again."</span>
+                <span class="p">)</span>
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"torch"</span><span class="p">):</span>
+            <span class="k">return</span> <span class="nb">eval</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span>
+
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">obj</span>
+
+    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">_serialise</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">obj</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Serialise `obj` to a format that can be saved to file."""</span>
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">):</span>
+            <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">as_dict</span><span class="p">()</span>
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">type</span><span class="p">):</span>
+            <span class="k">return</span> <span class="sa">f</span><span class="s2">"!class </span><span class="si">{</span><span class="n">obj</span><span class="o">.</span><span class="vm">__module__</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">obj</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">"</span>
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span>
+            <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="fm">__str__</span><span class="p">()</span>
+        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">Callable</span><span class="p">):</span>  <span class="c1"># type: ignore[arg-type]</span>
+            <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s2">"__name__"</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"&lt;lambda&gt;"</span><span class="p">:</span>
+                <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s2">"_source"</span><span class="p">):</span>
+                    <span class="c1"># If source code is set manually during deserialisation.</span>
+                    <span class="c1"># See `self._deserialise`.</span>
+                    <span class="n">source</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">_source</span>
+                <span class="k">else</span><span class="p">:</span>
+                    <span class="n">source</span> <span class="o">=</span> <span class="n">inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"="</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">strip</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2"> ,"</span><span class="p">)</span>
+
+                <span class="k">return</span> <span class="s2">"!"</span> <span class="o">+</span> <span class="n">source</span>
+            <span class="k">else</span><span class="p">:</span>
+                <span class="k">try</span><span class="p">:</span>
+                    <span class="n">source</span> <span class="o">=</span> <span class="n">inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span>
+                    <span class="n">match</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="n">FUNCTION_DEFINITION_PATTERN</span><span class="p">,</span> <span class="n">source</span><span class="p">)</span>
+                    <span class="k">if</span> <span class="n">match</span> <span class="ow">and</span> <span class="n">match</span><span class="o">.</span><span class="n">group</span><span class="p">(</span><span class="s2">"function_name"</span><span class="p">):</span>
+                        <span class="k">return</span> <span class="sa">f</span><span class="s2">"!function </span><span class="si">{</span><span class="n">source</span><span class="si">}</span><span class="s2">"</span>
+                    <span class="k">else</span><span class="p">:</span>
+                        <span class="k">raise</span> <span class="ne">ValueError</span>
+                <span class="k">except</span> <span class="p">(</span><span class="ne">TypeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">):</span>
+                    <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                        <span class="sa">f</span><span class="s2">"Object `</span><span class="si">{</span><span class="n">obj</span><span class="si">}</span><span class="s2">` is callable but not a lambda or "</span>
+                        <span class="s2">"regular function. Please wrap in a, e.g., lambda "</span>
+                        <span class="s2">"function to allow for saving this function verbatim "</span>
+                        <span class="s2">"in a model config file."</span>
+                    <span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">obj</span>
+
+<div class="viewcode-block" id="ModelConfig.as_dict">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.as_dict">[docs]</a>
+    <span class="k">def</span> <span class="nf">as_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span>
+<span class="w">        </span><span class="sd">"""Represent ModelConfig as a dict.</span>
+
+<span class="sd">        This builds on `BaseModel.dict()` but wraps the output in a single-key</span>
+<span class="sd">        dictionary to make it unambiguous to identify model arguments that are</span>
+<span class="sd">        themselves models.</span>
+<span class="sd">        """</span>
+        <span class="n">config_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dict</span><span class="p">()</span>
+        <span class="n">config_dict</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">]</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_serialise</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="p">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">:</span> <span class="n">config_dict</span><span class="p">}</span></div>
+</div>
+
+
+
+<div class="viewcode-block" id="save_model_config">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.save_model_config">[docs]</a>
+<span class="k">def</span> <span class="nf">save_model_config</span><span class="p">(</span><span class="n">init_fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Callable</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Save the arguments to `__init__` functions as a member `ModelConfig`."""</span>
+
+    <span class="k">def</span> <span class="nf">_replace_model_instance_with_config</span><span class="p">(</span>
+        <span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s2">"Model"</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span>
+<span class="w">        </span><span class="sd">"""Replace `Model` instances in `obj` with their `ModelConfig`."""</span>
+        <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span>
+
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">Model</span><span class="p">):</span>
+            <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">config</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="k">return</span> <span class="n">obj</span>
+
+    <span class="nd">@wraps</span><span class="p">(</span><span class="n">init_fn</span><span class="p">)</span>
+    <span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
+<span class="w">        </span><span class="sd">"""Set `ModelConfig` after calling `init_fn`."""</span>
+        <span class="c1"># Call wrapped method</span>
+        <span class="n">ret</span> <span class="o">=</span> <span class="n">init_fn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+        <span class="c1"># Get all argument values, including defaults</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">get_all_argument_values</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+        <span class="c1"># Handle nested `Model`s, etc.</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">_replace_model_instance_with_config</span><span class="p">)</span>
+
+        <span class="c1"># Add `ModelConfig` as member variables</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_config</span> <span class="o">=</span> <span class="n">ModelConfig</span><span class="p">(</span>
+            <span class="n">class_name</span><span class="o">=</span><span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">),</span>
+            <span class="n">arguments</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="o">**</span><span class="n">cfg</span><span class="p">),</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">ret</span>
+
+    <span class="k">return</span> <span class="n">wrapper</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/utilities/config/parsing.html b/_modules/graphnet/utilities/config/parsing.html
new file mode 100644
index 000000000..e1bff4682
--- /dev/null
+++ b/_modules/graphnet/utilities/config/parsing.html
@@ -0,0 +1,475 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.utilities.config.parsing &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/utilities/config/parsing" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.utilities.config.parsing </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-utilities-config-parsing--page-root">Source code for graphnet.utilities.config.parsing</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Utility functions for parsing for using with Config-classes."""</span>
+
+<span class="kn">import</span> <span class="nn">itertools</span>
+<span class="kn">import</span> <span class="nn">pkgutil</span>
+<span class="kn">import</span> <span class="nn">types</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">Any</span><span class="p">,</span>
+    <span class="n">Callable</span><span class="p">,</span>
+    <span class="n">Dict</span><span class="p">,</span>
+    <span class="n">List</span><span class="p">,</span>
+    <span class="n">Optional</span><span class="p">,</span>
+<span class="p">)</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span>
+
+
+<div class="viewcode-block" id="traverse_and_apply">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.traverse_and_apply">[docs]</a>
+<span class="k">def</span> <span class="nf">traverse_and_apply</span><span class="p">(</span>
+    <span class="n">obj</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">fn_kwargs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
+<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Apply `fn` to all elements in `obj`, resulting in same structure."""</span>
+    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
+        <span class="k">return</span> <span class="p">[</span><span class="n">traverse_and_apply</span><span class="p">(</span><span class="n">elem</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">fn_kwargs</span><span class="p">)</span> <span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="n">obj</span><span class="p">]</span>
+    <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
+        <span class="k">return</span> <span class="p">{</span>
+            <span class="n">key</span><span class="p">:</span> <span class="n">traverse_and_apply</span><span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">fn_kwargs</span><span class="p">)</span>
+            <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">obj</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
+        <span class="p">}</span>
+    <span class="k">else</span><span class="p">:</span>
+        <span class="k">if</span> <span class="n">fn_kwargs</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">fn_kwargs</span> <span class="o">=</span> <span class="p">{}</span>
+        <span class="k">return</span> <span class="n">fn</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="o">**</span><span class="n">fn_kwargs</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="list_all_submodules">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.list_all_submodules">[docs]</a>
+<span class="k">def</span> <span class="nf">list_all_submodules</span><span class="p">(</span><span class="o">*</span><span class="n">packages</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">]:</span>
+<span class="w">    </span><span class="sd">"""List all submodules in `packages` recursively."""</span>
+    <span class="c1"># Resolve one or more packages</span>
+    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">packages</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
+        <span class="k">return</span> <span class="nb">list</span><span class="p">(</span>
+            <span class="n">itertools</span><span class="o">.</span><span class="n">chain</span><span class="o">.</span><span class="n">from_iterable</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="n">list_all_submodules</span><span class="p">,</span> <span class="n">packages</span><span class="p">))</span>
+        <span class="p">)</span>
+    <span class="k">else</span><span class="p">:</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">packages</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"No packages specified"</span>
+        <span class="n">package</span> <span class="o">=</span> <span class="n">packages</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+
+    <span class="n">submodules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
+    <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">module_name</span><span class="p">,</span> <span class="n">is_pkg</span> <span class="ow">in</span> <span class="n">pkgutil</span><span class="o">.</span><span class="n">walk_packages</span><span class="p">(</span>
+        <span class="n">package</span><span class="o">.</span><span class="n">__path__</span><span class="p">,</span> <span class="n">package</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">+</span> <span class="s2">"."</span>
+    <span class="p">):</span>
+        <span class="n">module</span> <span class="o">=</span> <span class="nb">__import__</span><span class="p">(</span><span class="n">module_name</span><span class="p">,</span> <span class="n">fromlist</span><span class="o">=</span><span class="s2">"dummylist"</span><span class="p">)</span>
+        <span class="n">submodules</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">module</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">is_pkg</span><span class="p">:</span>
+            <span class="n">submodules</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">list_all_submodules</span><span class="p">(</span><span class="n">module</span><span class="p">))</span>
+
+    <span class="k">return</span> <span class="n">submodules</span></div>
+
+
+
+<div class="viewcode-block" id="get_all_grapnet_classes">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_all_grapnet_classes">[docs]</a>
+<span class="k">def</span> <span class="nf">get_all_grapnet_classes</span><span class="p">(</span><span class="o">*</span><span class="n">packages</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">type</span><span class="p">]:</span>
+<span class="w">    </span><span class="sd">"""List all grapnet classes in `packages`."""</span>
+    <span class="n">submodules</span> <span class="o">=</span> <span class="n">list_all_submodules</span><span class="p">(</span><span class="o">*</span><span class="n">packages</span><span class="p">)</span>
+    <span class="n">classes</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">type</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
+    <span class="k">for</span> <span class="n">submodule</span> <span class="ow">in</span> <span class="n">submodules</span><span class="p">:</span>
+        <span class="n">new_classes</span> <span class="o">=</span> <span class="n">get_graphnet_classes</span><span class="p">(</span><span class="n">submodule</span><span class="p">)</span>
+        <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">new_classes</span><span class="p">:</span>
+            <span class="k">if</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">classes</span> <span class="ow">and</span> <span class="n">classes</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">!=</span> <span class="n">new_classes</span><span class="p">[</span><span class="n">key</span><span class="p">]:</span>
+                <span class="n">Logger</span><span class="p">()</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                    <span class="sa">f</span><span class="s2">"Class </span><span class="si">{</span><span class="n">key</span><span class="si">}</span><span class="s2"> found in both </span><span class="si">{</span><span class="n">classes</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="si">}</span><span class="s2"> and "</span>
+                    <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">new_classes</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="si">}</span><span class="s2">. Keeping first instance. "</span>
+                    <span class="s2">"Consider renaming."</span>
+                <span class="p">)</span>
+        <span class="n">classes</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">new_classes</span><span class="p">)</span>
+
+    <span class="k">return</span> <span class="n">classes</span></div>
+
+
+
+<div class="viewcode-block" id="is_graphnet_module">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_module">[docs]</a>
+<span class="k">def</span> <span class="nf">is_graphnet_module</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Return whether `obj` is a module in graphnet."""</span>
+    <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="vm">__name__</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span>
+        <span class="s2">"graphnet."</span>
+    <span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="is_graphnet_class">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_class">[docs]</a>
+<span class="k">def</span> <span class="nf">is_graphnet_class</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="nb">type</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Return whether `obj` is a class in graphnet."""</span>
+    <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">type</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="vm">__module__</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"graphnet."</span><span class="p">)</span></div>
+
+
+
+<div class="viewcode-block" id="get_graphnet_classes">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_graphnet_classes">[docs]</a>
+<span class="k">def</span> <span class="nf">get_graphnet_classes</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">type</span><span class="p">]:</span>
+<span class="w">    </span><span class="sd">"""Return a lookup of all graphnet class names in `module`."""</span>
+    <span class="k">if</span> <span class="ow">not</span> <span class="n">is_graphnet_module</span><span class="p">(</span><span class="n">module</span><span class="p">):</span>
+        <span class="n">Logger</span><span class="p">()</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2"> is not a graphnet module"</span><span class="p">)</span>
+        <span class="k">return</span> <span class="p">{}</span>
+    <span class="n">classes</span> <span class="o">=</span> <span class="p">{</span>
+        <span class="n">key</span><span class="p">:</span> <span class="n">val</span>
+        <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
+        <span class="k">if</span> <span class="n">is_graphnet_class</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
+    <span class="p">}</span>
+    <span class="k">return</span> <span class="n">classes</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/utilities/config/training_config.html b/_modules/graphnet/utilities/config/training_config.html
new file mode 100644
index 000000000..add0468a2
--- /dev/null
+++ b/_modules/graphnet/utilities/config/training_config.html
@@ -0,0 +1,378 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.utilities.config.training_config &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" />
+    <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../../about.html" />
+    <link rel="index" title="Index" href="../../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/utilities/config/training_config" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.utilities.config.training_config </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../../"versions.json"",
+        target_loc = "../../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-utilities-config-training-config--page-root">Source code for graphnet.utilities.config.training_config</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Config classes for the `graphnet.training` module."""</span>
+
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span>
+
+<span class="kn">from</span> <span class="nn">graphnet.utilities.config.base_config</span> <span class="kn">import</span> <span class="n">BaseConfig</span>
+
+
+<div class="viewcode-block" id="TrainingConfig">
+<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig">[docs]</a>
+<span class="k">class</span> <span class="nc">TrainingConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span>
+<span class="w">    </span><span class="sd">"""Configuration for all trainings."""</span>
+
+    <span class="c1"># Fields</span>
+    <span class="n">target</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span>
+    <span class="n">early_stopping_patience</span><span class="p">:</span> <span class="nb">int</span>
+    <span class="n">fit</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
+    <span class="n">dataloader</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/graphnet/utilities/filesys.html b/_modules/graphnet/utilities/filesys.html
index ef2600ad3..121c9bfe5 100644
--- a/_modules/graphnet/utilities/filesys.html
+++ b/_modules/graphnet/utilities/filesys.html
@@ -447,7 +447,7 @@ <h1 id="modules-graphnet-utilities-filesys--page-root">Source code for graphnet.
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/utilities/imports.html b/_modules/graphnet/utilities/imports.html
index 40699ad6e..80debdadf 100644
--- a/_modules/graphnet/utilities/imports.html
+++ b/_modules/graphnet/utilities/imports.html
@@ -420,7 +420,7 @@ <h1 id="modules-graphnet-utilities-imports--page-root">Source code for graphnet.
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/utilities/logging.html b/_modules/graphnet/utilities/logging.html
index cb014897d..aa1b8a091 100644
--- a/_modules/graphnet/utilities/logging.html
+++ b/_modules/graphnet/utilities/logging.html
@@ -630,7 +630,7 @@ <h1 id="modules-graphnet-utilities-logging--page-root">Source code for graphnet.
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/_modules/graphnet/utilities/maths.html b/_modules/graphnet/utilities/maths.html
new file mode 100644
index 000000000..607211db6
--- /dev/null
+++ b/_modules/graphnet/utilities/maths.html
@@ -0,0 +1,371 @@
+<!DOCTYPE html>
+
+<html lang="en" data-content_root="../../../">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+
+
+  <meta name="viewport" content="width=device-width,initial-scale=1">
+  <meta http-equiv="x-ua-compatible" content="ie=edge">
+  <meta name="lang:clipboard.copy" content="Copy to clipboard">
+  <meta name="lang:clipboard.copied" content="Copied to clipboard">
+  <meta name="lang:search.language" content="en">
+  <meta name="lang:search.pipeline.stopwords" content="True">
+  <meta name="lang:search.pipeline.trimmer" content="True">
+  <meta name="lang:search.result.none" content="No matching documents">
+  <meta name="lang:search.result.one" content="1 matching document">
+  <meta name="lang:search.result.other" content="# matching documents">
+  <meta name="lang:search.tokenizer" content="[\s\-]+">
+
+  
+    <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin>
+    <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet">
+
+    <style>
+      body,
+      input {
+        font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif
+      }
+
+      code,
+      kbd,
+      pre {
+        font-family: "Roboto Mono", "Courier New", Courier, monospace
+      }
+    </style>
+  
+
+  <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/>
+  <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/>
+  
+  <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/>
+  
+  <meta name="theme-color" content="#3f51b5">
+  <script src="../../../_static/javascripts/modernizr.js"></script>
+  
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script>
+<script>
+    window.dataLayer = window.dataLayer || [];
+
+    function gtag() {
+        dataLayer.push(arguments);
+    }
+
+    gtag('js', new Date());
+
+    gtag('config', 'UA-XXXXX');
+</script>
+  
+  
+    <title>graphnet.utilities.maths &#8212; graphnet  documentation</title>
+
+<style>
+  dt:target {
+    margin-top: 0;
+    padding-top: 0;
+  }
+
+
+  /*
+    .sig-prename {
+     display: none;
+  }
+  */
+
+  .py.class .sig-name,
+  .py.function .sig-name,
+  .py.method .sig-name,
+  .py.exception .sig-name {
+    color: #37474f;
+    font-feature-settings: "kern";
+    font-family: "Roboto Mono", "Courier New", Courier, monospace;
+    font-weight: 700;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.method .sig-object,
+  .py.exception .sig-object {
+    padding: 1ex;
+  }
+
+  .py.class .sig-object,
+  .py.function .sig-object,
+  .py.exception .sig-object {
+    border-top: 1px solid gray;
+  }
+  .py.method .sig-object {
+    border-top: 1px solid lightgray;
+  }
+
+  .py.class .sig-object,
+  .py.exception .sig-object {
+    background: rgba(0,0,0,0.06);
+  }
+  .py.function .sig-object,
+  .py.method .sig-object {
+    background: rgba(0,0,0,0.03);
+  }
+
+  #eu-emblem {
+    margin: 0;
+  }
+
+  #eu-emblem figcaption {
+    display: none;
+  }
+
+</style>
+    <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" />
+    <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" />
+    <script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
+    <script src="../../../_static/doctools.js?v=888ff710"></script>
+    <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <link rel="icon" href="../../../_static/favicon.svg"/>
+    <link rel="author" title="About these documents" href="../../../about.html" />
+    <link rel="index" title="Index" href="../../../genindex.html" />
+    <link rel="search" title="Search" href="../../../search.html" />
+  
+   
+
+  </head>
+  <body dir=ltr
+        data-md-color-primary=indigo data-md-color-accent=blue>
+  
+  <svg class="md-svg">
+    <defs data-children-count="0">
+      
+      <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg>
+      
+    </defs>
+  </svg>
+  
+  <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer">
+  <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search">
+  <label class="md-overlay" data-md-component="overlay" for="__drawer"></label>
+  <a href="#_modules/graphnet/utilities/maths" tabindex="1" class="md-skip"> Skip to content </a>
+  <header class="md-header" data-md-component="header">
+  <nav class="md-header-nav md-grid">
+    <div class="md-flex navheader">
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <a href="../../../index.html" title="graphnet  documentation"
+           class="md-header-nav__button md-logo">
+          
+            &nbsp;
+          
+        </a>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label>
+      </div>
+      <div class="md-flex__cell md-flex__cell--stretch">
+        <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title">
+          <span class="md-header-nav__topic">GraphNeT</span>
+          <span class="md-header-nav__topic"> graphnet.utilities.maths </span>
+        </div>
+      </div>
+      <div class="md-flex__cell md-flex__cell--shrink">
+        <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label>
+        
+<div class="md-search" data-md-component="search" role="dialog">
+  <label class="md-search__overlay" for="__search"></label>
+  <div class="md-search__inner" role="search">
+    <form class="md-search__form" action="../../../search.html" method="get" name="search">
+      <input type="text" class="md-search__input" name="q" placeholder=""Search""
+             autocapitalize="off" autocomplete="off" spellcheck="false"
+             data-md-component="query" data-md-state="active">
+      <label class="md-icon md-search__icon" for="__search"></label>
+      <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1">
+        &#xE5CD;
+      </button>
+    </form>
+    <div class="md-search__output">
+      <div class="md-search__scrollwrap" data-md-scrollfix>
+        <div class="md-search-result" data-md-component="result">
+          <div class="md-search-result__meta">
+            Type to start searching
+          </div>
+          <ol class="md-search-result__list"></ol>
+        </div>
+      </div>
+    </div>
+  </div>
+</div>
+
+      </div>
+      
+        <div class="md-flex__cell md-flex__cell--shrink">
+          <div class="md-header-nav__source">
+            <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+          </div>
+        </div>
+      
+      
+  
+  <script src="../../../_static/javascripts/version_dropdown.js"></script>
+  <script>
+    var json_loc = "../../../"versions.json"",
+        target_loc = "../../../../",
+        text = "Versions";
+    $( document ).ready( add_version_dropdown(json_loc, target_loc, text));
+  </script>
+  
+
+    </div>
+  </nav>
+</header>
+
+  
+  <div class="md-container">
+    
+    
+    
+  <nav class="md-tabs" data-md-component="tabs">
+    <div class="md-tabs__inner md-grid">
+      <ul class="md-tabs__list">
+            
+            <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li>
+          <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li>
+      </ul>
+    </div>
+  </nav>
+    <main class="md-main">
+      <div class="md-main__inner md-grid" data-md-component="container">
+        
+          <div class="md-sidebar md-sidebar--primary" data-md-component="navigation">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                <nav class="md-nav md-nav--primary" data-md-level="0">
+  <label class="md-nav__title md-nav__title--site" for="__drawer">
+    <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo">
+      
+        <img src="../../../_static/" alt=" logo" width="48" height="48">
+      
+    </a>
+    <a href="../../../index.html"
+       title="graphnet documentation">GraphNeT</a>
+  </label>
+    <div class="md-nav__source">
+      <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github">
+
+    <div class="md-source__icon">
+      <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28">
+        <use xlink:href="#__github" width="24" height="24"></use>
+      </svg>
+    </div>
+  
+  <div class="md-source__repository">
+    GraphNeT
+  </div>
+</a>
+    </div>
+  
+  
+
+  
+  <ul class="md-nav__list">
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../install.html" class="md-nav__link">Install</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../contribute.html" class="md-nav__link">Contribute</a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="../../../api/graphnet.html" class="md-nav__link">API</a>
+      
+    
+    </li>
+  </ul>
+  
+
+</nav>
+              </div>
+            </div>
+          </div>
+          <div class="md-sidebar md-sidebar--secondary" data-md-component="toc">
+            <div class="md-sidebar__scrollwrap">
+              <div class="md-sidebar__inner">
+                
+<nav class="md-nav md-nav--secondary">
+  <ul class="md-nav__list" data-md-scrollfix="">
+  </ul>
+</nav>
+              </div>
+            </div>
+          </div>
+        
+        <div class="md-content">
+          <article class="md-content__inner md-typeset" role="main">
+            
+  <h1 id="modules-graphnet-utilities-maths--page-root">Source code for graphnet.utilities.maths</h1><div class="highlight"><pre>
+<span></span><span class="sd">"""Collection of assorted "maths-like" functions."""</span>
+
+<span class="kn">import</span> <span class="nn">torch</span>
+
+
+<div class="viewcode-block" id="eps_like">
+<a class="viewcode-back" href="../../../api/graphnet.utilities.maths.html#graphnet.utilities.maths.eps_like">[docs]</a>
+<span class="k">def</span> <span class="nf">eps_like</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
+<span class="w">    </span><span class="sd">"""Return `eps` matching `tensor`'s dtype."""</span>
+    <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">finfo</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">eps</span></div>
+
+</pre></div>
+
+          </article>
+        </div>
+      </div>
+    </main>
+  </div>
+  <footer class="md-footer">
+    <div class="md-footer-nav">
+      <nav class="md-footer-nav__inner md-grid">
+          
+          
+        </a>
+        
+      </nav>
+    </div>
+    <div class="md-footer-meta md-typeset">
+      <div class="md-footer-meta__inner md-grid">
+        <div class="md-footer-copyright">
+          <div class="md-footer-copyright__highlight">
+              &#169; Copyright 2021-2023, GraphNeT team.
+              
+          </div>
+            Created using
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
+             and
+            <a href="https://github.com/bashtage/sphinx-material/">Material for
+              Sphinx</a>
+        </div>
+      </div>
+    </div>
+  </footer>
+  <script src="../../../_static/javascripts/application.js"></script>
+  <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script>
+  </body>
+</html>
\ No newline at end of file
diff --git a/_modules/index.html b/_modules/index.html
index 89db4987e..120145b62 100644
--- a/_modules/index.html
+++ b/_modules/index.html
@@ -323,6 +323,11 @@
   <h1 id="modules-index--page-root">All modules for which code is available</h1>
 <ul><li><a href="graphnet/data/constants.html">graphnet.data.constants</a></li>
 <li><a href="graphnet/data/dataconverter.html">graphnet.data.dataconverter</a></li>
+<li><a href="graphnet/data/dataloader.html">graphnet.data.dataloader</a></li>
+<li><a href="graphnet/data/dataset/dataset.html">graphnet.data.dataset.dataset</a></li>
+<li><a href="graphnet/data/dataset/parquet/parquet_dataset.html">graphnet.data.dataset.parquet.parquet_dataset</a></li>
+<li><a href="graphnet/data/dataset/sqlite/sqlite_dataset.html">graphnet.data.dataset.sqlite.sqlite_dataset</a></li>
+<li><a href="graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html">graphnet.data.dataset.sqlite.sqlite_dataset_perturbed</a></li>
 <li><a href="graphnet/data/extractors/i3extractor.html">graphnet.data.extractors.i3extractor</a></li>
 <li><a href="graphnet/data/extractors/i3featureextractor.html">graphnet.data.extractors.i3featureextractor</a></li>
 <li><a href="graphnet/data/extractors/i3genericextractor.html">graphnet.data.extractors.i3genericextractor</a></li>
@@ -339,18 +344,52 @@ <h1 id="modules-index--page-root">All modules for which code is available</h1>
 <li><a href="graphnet/data/extractors/utilities/frames.html">graphnet.data.extractors.utilities.frames</a></li>
 <li><a href="graphnet/data/extractors/utilities/types.html">graphnet.data.extractors.utilities.types</a></li>
 <li><a href="graphnet/data/parquet/parquet_dataconverter.html">graphnet.data.parquet.parquet_dataconverter</a></li>
+<li><a href="graphnet/data/pipeline.html">graphnet.data.pipeline</a></li>
 <li><a href="graphnet/data/sqlite/sqlite_dataconverter.html">graphnet.data.sqlite.sqlite_dataconverter</a></li>
 <li><a href="graphnet/data/sqlite/sqlite_utilities.html">graphnet.data.sqlite.sqlite_utilities</a></li>
 <li><a href="graphnet/data/utilities/parquet_to_sqlite.html">graphnet.data.utilities.parquet_to_sqlite</a></li>
 <li><a href="graphnet/data/utilities/random.html">graphnet.data.utilities.random</a></li>
 <li><a href="graphnet/data/utilities/string_selection_resolver.html">graphnet.data.utilities.string_selection_resolver</a></li>
+<li><a href="graphnet/deployment/i3modules/graphnet_module.html">graphnet.deployment.i3modules.graphnet_module</a></li>
+<li><a href="graphnet/models/coarsening.html">graphnet.models.coarsening</a></li>
+<li><a href="graphnet/models/components/layers.html">graphnet.models.components.layers</a></li>
+<li><a href="graphnet/models/components/pool.html">graphnet.models.components.pool</a></li>
+<li><a href="graphnet/models/detector/detector.html">graphnet.models.detector.detector</a></li>
+<li><a href="graphnet/models/detector/icecube.html">graphnet.models.detector.icecube</a></li>
+<li><a href="graphnet/models/detector/prometheus.html">graphnet.models.detector.prometheus</a></li>
+<li><a href="graphnet/models/gnn/convnet.html">graphnet.models.gnn.convnet</a></li>
+<li><a href="graphnet/models/gnn/dynedge.html">graphnet.models.gnn.dynedge</a></li>
+<li><a href="graphnet/models/gnn/dynedge_jinst.html">graphnet.models.gnn.dynedge_jinst</a></li>
+<li><a href="graphnet/models/gnn/dynedge_kaggle_tito.html">graphnet.models.gnn.dynedge_kaggle_tito</a></li>
+<li><a href="graphnet/models/gnn/gnn.html">graphnet.models.gnn.gnn</a></li>
+<li><a href="graphnet/models/graphs/edges/edges.html">graphnet.models.graphs.edges.edges</a></li>
+<li><a href="graphnet/models/graphs/graph_definition.html">graphnet.models.graphs.graph_definition</a></li>
+<li><a href="graphnet/models/graphs/graphs.html">graphnet.models.graphs.graphs</a></li>
+<li><a href="graphnet/models/graphs/nodes/nodes.html">graphnet.models.graphs.nodes.nodes</a></li>
+<li><a href="graphnet/models/model.html">graphnet.models.model</a></li>
+<li><a href="graphnet/models/standard_model.html">graphnet.models.standard_model</a></li>
+<li><a href="graphnet/models/task/classification.html">graphnet.models.task.classification</a></li>
+<li><a href="graphnet/models/task/reconstruction.html">graphnet.models.task.reconstruction</a></li>
+<li><a href="graphnet/models/task/task.html">graphnet.models.task.task</a></li>
+<li><a href="graphnet/models/utils.html">graphnet.models.utils</a></li>
 <li><a href="graphnet/pisa/fitting.html">graphnet.pisa.fitting</a></li>
 <li><a href="graphnet/pisa/plotting.html">graphnet.pisa.plotting</a></li>
+<li><a href="graphnet/training/callbacks.html">graphnet.training.callbacks</a></li>
+<li><a href="graphnet/training/labels.html">graphnet.training.labels</a></li>
+<li><a href="graphnet/training/loss_functions.html">graphnet.training.loss_functions</a></li>
+<li><a href="graphnet/training/utils.html">graphnet.training.utils</a></li>
 <li><a href="graphnet/training/weight_fitting.html">graphnet.training.weight_fitting</a></li>
 <li><a href="graphnet/utilities/argparse.html">graphnet.utilities.argparse</a></li>
+<li><a href="graphnet/utilities/config/base_config.html">graphnet.utilities.config.base_config</a></li>
+<li><a href="graphnet/utilities/config/configurable.html">graphnet.utilities.config.configurable</a></li>
+<li><a href="graphnet/utilities/config/dataset_config.html">graphnet.utilities.config.dataset_config</a></li>
+<li><a href="graphnet/utilities/config/model_config.html">graphnet.utilities.config.model_config</a></li>
+<li><a href="graphnet/utilities/config/parsing.html">graphnet.utilities.config.parsing</a></li>
+<li><a href="graphnet/utilities/config/training_config.html">graphnet.utilities.config.training_config</a></li>
 <li><a href="graphnet/utilities/filesys.html">graphnet.utilities.filesys</a></li>
 <li><a href="graphnet/utilities/imports.html">graphnet.utilities.imports</a></li>
 <li><a href="graphnet/utilities/logging.html">graphnet.utilities.logging</a></li>
+<li><a href="graphnet/utilities/maths.html">graphnet.utilities.maths</a></li>
 </ul>
 
           </article>
@@ -375,7 +414,7 @@ <h1 id="modules-index--page-root">All modules for which code is available</h1>
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/about.html b/about.html
index 3ebaa11ee..6422d9cf3 100644
--- a/about.html
+++ b/about.html
@@ -392,7 +392,7 @@ <h2 id="acknowledgements">Acknowledgements<a class="headerlink" href="#acknowled
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.constants.html b/api/graphnet.constants.html
index 5cc170ec7..0a215dbf4 100644
--- a/api/graphnet.constants.html
+++ b/api/graphnet.constants.html
@@ -422,7 +422,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.constants.html b/api/graphnet.data.constants.html
index 36edc3f62..32d33f303 100644
--- a/api/graphnet.data.constants.html
+++ b/api/graphnet.data.constants.html
@@ -700,7 +700,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataconverter.html b/api/graphnet.data.dataconverter.html
index 61d9300a1..40c210634 100644
--- a/api/graphnet.data.dataconverter.html
+++ b/api/graphnet.data.dataconverter.html
@@ -820,7 +820,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataloader.html b/api/graphnet.data.dataloader.html
index 8ad405a2b..98bb280f5 100644
--- a/api/graphnet.data.dataloader.html
+++ b/api/graphnet.data.dataloader.html
@@ -365,11 +365,54 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataloader--page-root" class="md-nav__link">dataloader</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataloader.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataloader.do_shuffle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">do_shuffle()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataloader.DataLoader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataloader.DataLoader.from_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_dataset_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataloader.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataloader.do_shuffle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">do_shuffle()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataloader.DataLoader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataloader.DataLoader.from_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_dataset_config()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -436,7 +479,22 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataloader--page-root" class="md-nav__link">dataloader</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataloader.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataloader.do_shuffle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">do_shuffle()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataloader.DataLoader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataloader.DataLoader.from_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_dataset_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -446,8 +504,73 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="dataloader">
-<h1 id="api-graphnet-data-dataloader--page-root">dataloader<a class="headerlink" href="#api-graphnet-data-dataloader--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.dataloader">
+<span id="dataloader"></span><h1 id="api-graphnet-data-dataloader--page-root">dataloader<a class="headerlink" href="#api-graphnet-data-dataloader--page-root" title="Link to this heading">¶</a></h1>
+<p>Base <cite>Dataloader</cite> class(es) used in <cite>graphnet</cite>.</p>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.data.dataloader.collate_fn">
+<span class="sig-prename descclassname"><span class="pre">graphnet.data.dataloader.</span></span><span class="sig-name descname"><span class="pre">collate_fn</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">graphs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataloader.html#collate_fn"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataloader.collate_fn" title="Link to this definition">¶</a></dt>
+<dd><p>Remove graphs with less than two DOM hits.</p>
+<p>Should not occur in “production.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>graphs</strong> (<em>List</em><em>[</em><em>Data</em><em>]</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.data.dataloader.do_shuffle">
+<span class="sig-prename descclassname"><span class="pre">graphnet.data.dataloader.</span></span><span class="sig-name descname"><span class="pre">do_shuffle</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">selection_name</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataloader.html#do_shuffle"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataloader.do_shuffle" title="Link to this definition">¶</a></dt>
+<dd><p>Check whether to shuffle selection with name <cite>selection_name</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>selection_name</strong> (<em>str</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.data.dataloader.DataLoader">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataloader.</span></span><span class="sig-name descname"><span class="pre">DataLoader</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shuffle</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">persistent_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">collate_fn=&lt;function</span> <span class="pre">collate_fn&gt;</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prefetch_factor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">**kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataloader.html#DataLoader"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataloader.DataLoader" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code></p>
+<p>Class for loading data from a <cite>Dataset</cite>.</p>
+<p>Construct <cite>DataLoader</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>dataset</strong> (<em>Dataset</em><em>[</em><em>T_co</em><em>]</em>) – </p></li>
+<li><p><strong>batch_size</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>shuffle</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>num_workers</strong> (<em>int</em>) – </p></li>
+<li><p><strong>persistent_workers</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>collate_fn</strong> (<em>Callable</em>) – </p></li>
+<li><p><strong>prefetch_factor</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.data.dataloader.DataLoader.from_dataset_config">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">from_dataset_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">config</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataloader.html#DataLoader.from_dataset_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataloader.DataLoader.from_dataset_config" title="Link to this definition">¶</a></dt>
+<dd><p>Construct <cite>DataLoader`s based on selections in `DatasetConfig</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<a class="reference internal" href="#graphnet.data.dataloader.DataLoader" title="graphnet.data.dataloader.DataLoader"><code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <a class="reference internal" href="#graphnet.data.dataloader.DataLoader" title="graphnet.data.dataloader.DataLoader"><code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code></a>]]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>config</strong> (<a class="reference internal" href="graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig" title="graphnet.utilities.config.dataset_config.DatasetConfig"><em>DatasetConfig</em></a>) – </p></li>
+<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -497,7 +620,7 @@ <h1 id="api-graphnet-data-dataloader--page-root">dataloader<a class="headerlink"
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataset.dataset.html b/api/graphnet.data.dataset.dataset.html
index 0750d2bf5..ac7ed7ec2 100644
--- a/api/graphnet.data.dataset.dataset.html
+++ b/api/graphnet.data.dataset.dataset.html
@@ -336,11 +336,117 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataset-dataset--page-root" class="md-nav__link">dataset</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.ColumnMissingException" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ColumnMissingException</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.load_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_module()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.parse_graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">parse_graph_definition()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Dataset</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.concatenate" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">concatenate()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.add_label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">add_label()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.EnsembleDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.ColumnMissingException" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ColumnMissingException</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.load_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_module()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.parse_graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">parse_graph_definition()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.Dataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Dataset</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.Dataset.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.Dataset.concatenate" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">concatenate()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.Dataset.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.Dataset.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.Dataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.Dataset.add_label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">add_label()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.dataset.EnsembleDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a>
+      
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -458,7 +564,36 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataset-dataset--page-root" class="md-nav__link">dataset</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.ColumnMissingException" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ColumnMissingException</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.load_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_module()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.parse_graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">parse_graph_definition()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Dataset</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.concatenate" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">concatenate()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.add_label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">add_label()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.EnsembleDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -468,8 +603,197 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="dataset">
-<h1 id="api-graphnet-data-dataset-dataset--page-root">dataset<a class="headerlink" href="#api-graphnet-data-dataset-dataset--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.dataset.dataset">
+<span id="dataset"></span><h1 id="api-graphnet-data-dataset-dataset--page-root">dataset<a class="headerlink" href="#api-graphnet-data-dataset-dataset--page-root" title="Link to this heading">¶</a></h1>
+<p>Base <a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a> class(es) used in GraphNeT.</p>
+<dl class="py exception">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.ColumnMissingException">
+<em class="property"><span class="pre">exception</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">ColumnMissingException</span></span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#ColumnMissingException"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.ColumnMissingException" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Exception</span></code></p>
+<p>Exception to indicate a missing column in a dataset.</p>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.load_module">
+<span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">load_module</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">class_name</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#load_module"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.load_module" title="Link to this definition">¶</a></dt>
+<dd><p>Load graphnet module from string name.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>class_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – name of class</p>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Type</span></code></p>
+</dd>
+<dt class="field-odd">Returns<span class="colon">:</span></dt>
+<dd class="field-odd"><p>graphnet module.</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.parse_graph_definition">
+<span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">parse_graph_definition</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cfg</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#parse_graph_definition"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.parse_graph_definition" title="Link to this definition">¶</a></dt>
+<dd><p>Construct GraphDefinition from DatasetConfig.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>cfg</strong> (<em>dict</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">Dataset</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger" title="graphnet.utilities.logging.Logger"><code class="xref py py-class docutils literal notranslate"><span class="pre">Logger</span></code></a>, <a class="reference internal" href="graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable" title="graphnet.utilities.config.configurable.Configurable"><code class="xref py py-class docutils literal notranslate"><span class="pre">Configurable</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code></p>
+<p>Base Dataset class for reading from any intermediate file format.</p>
+<p>Construct Dataset.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Path to the file(s) from which this <cite>Dataset</cite> should read.</p></li>
+<li><p><strong>pulsemaps</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Name(s) of the pulse map series that should be used to
+construct the nodes on the individual graph objects, and their
+features. Multiple pulse series maps can be used, e.g., when
+different DOM types are stored in different maps.</p></li>
+<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of columns in the input files that should be used as
+node features on the graph objects.</p></li>
+<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of event-level columns in the input files that should
+be used added as attributes on the  graph objects.</p></li>
+<li><p><strong>node_truth</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of node-level columns in the input files that
+should be used added as attributes on the graph objects.</p></li>
+<li><p><strong>index_column</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'event_no'</span></code>) – Name of the column in the input files that contains
+unique indicies to identify and map events across tables.</p></li>
+<li><p><strong>truth_table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'truth'</span></code>) – Name of the table containing event-level truth
+information.</p></li>
+<li><p><strong>node_truth_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing node-level truth
+information.</p></li>
+<li><p><strong>string_selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of strings for which data should be read
+and used to construct graph objects. Defaults to None, meaning
+all strings for which data exists are used.</p></li>
+<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The events that should be read. This can be given either
+as list of indicies (in <cite>index_column</cite>); or a string-based
+selection used to query the <cite>Dataset</cite> for events passing the
+selection. Defaults to None, meaning that all events in the
+input files are read.</p></li>
+<li><p><strong>dtype</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>, default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – Type of the feature tensor on the graph objects returned.</p></li>
+<li><p><strong>loss_weight_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing per-event loss
+weights.</p></li>
+<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the column in <cite>loss_weight_table</cite>
+containing per-event loss weights. This is also the name of the
+corresponding attribute assigned to the graph object.</p></li>
+<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Default per-event loss weight.
+NOTE: This default value is only applied when
+<cite>loss_weight_table</cite> and <cite>loss_weight_column</cite> are specified, and
+in this case to events with no value in the corresponding
+table/column. That is, if no per-event loss weight table/column
+is provided, this value is ignored. Defaults to None.</p></li>
+<li><p><strong>seed</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Random number generator seed, used for selecting a random
+subset of events when resolving a string-based selection (e.g.,
+<cite>“10000 random events ~ event_no % 5 &gt; 0”</cite> or <cite>“20% random
+events ~ event_no % 5 &gt; 0”</cite>).</p></li>
+<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>) – Method that defines the graph representation.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.from_config">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">from_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">source</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset.from_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.from_config" title="Link to this definition">¶</a></dt>
+<dd><p>Construct <cite>Dataset</cite> instance from <cite>source</cite> configuration.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a>, <a class="reference internal" href="#graphnet.data.dataset.dataset.EnsembleDataset" title="graphnet.data.dataset.dataset.EnsembleDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a>], <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <a class="reference internal" href="#graphnet.data.dataset.dataset.EnsembleDataset" title="graphnet.data.dataset.dataset.EnsembleDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a>]]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>source</strong> (<a class="reference internal" href="graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig" title="graphnet.utilities.config.dataset_config.DatasetConfig"><em>DatasetConfig</em></a><em> | </em><em>str</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.concatenate">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">concatenate</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">datasets</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset.concatenate"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.concatenate" title="Link to this definition">¶</a></dt>
+<dd><p>Concatenate multiple <a href="#id1"><span class="problematic" id="id2">`</span></a>Dataset`s into one instance.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><a class="reference internal" href="#graphnet.data.dataset.dataset.EnsembleDataset" title="graphnet.data.dataset.dataset.EnsembleDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>datasets</strong> (<em>List</em><em>[</em><a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><em>Dataset</em></a><em>]</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.path">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">path</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">str</span><span class="w"> </span><span class="p"><span class="pre">|</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.path" title="Link to this definition">¶</a></dt>
+<dd><p>Path to the file(s) from which this <cite>Dataset</cite> reads.</p>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.truth_table">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">truth_table</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">str</span></em><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.truth_table" title="Link to this definition">¶</a></dt>
+<dd><p>Name of the table containing event-level truth information.</p>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.query_table">
+<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">query_table</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sequential_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset.query_table"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.query_table" title="Link to this definition">¶</a></dt>
+<dd><p>Query a table at a specific index, optionally with some selection.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – Table to be queried.</p></li>
+<li><p><strong>columns</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – Columns to read out.</p></li>
+<li><p><strong>sequential_index</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Sequentially numbered index
+(i.e. in [0,len(self))) of the event to query. This _may_
+differ from the indexation used in <cite>self._indices</cite>. If no value
+is provided, the entire column is returned.</p></li>
+<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Selection to be imposed before reading out data.
+Defaults to None.</p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]</p>
+</dd>
+<dt class="field-odd">Returns<span class="colon">:</span></dt>
+<dd class="field-odd"><p><dl class="simple">
+<dt>List of tuples containing the values in <cite>columns</cite>. If the <cite>table</cite></dt><dd><p>contains only scalar data for <cite>columns</cite>, a list of length 1 is
+returned</p>
+</dd>
+</dl>
+</p>
+</dd>
+<dt class="field-even">Raises<span class="colon">:</span></dt>
+<dd class="field-even"><p><a class="reference internal" href="#graphnet.data.dataset.dataset.ColumnMissingException" title="graphnet.data.dataset.dataset.ColumnMissingException"><strong>ColumnMissingException</strong></a> – If one or more element in <cite>columns</cite> is not
+    present in <cite>table</cite>.</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.add_label">
+<span class="sig-name descname"><span class="pre">add_label</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">key</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset.add_label"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.add_label" title="Link to this definition">¶</a></dt>
+<dd><p>Add custom graph label define using function <cite>fn</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>fn</strong> (<em>Callable</em><em>[</em><em>[</em><em>Data</em><em>]</em><em>, </em><em>Any</em><em>]</em>) – </p></li>
+<li><p><strong>key</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.EnsembleDataset">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">EnsembleDataset</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">datasets</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#EnsembleDataset"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.EnsembleDataset" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">ConcatDataset</span></code></p>
+<p>Construct a single dataset from a collection of datasets.</p>
+<p>Construct a single dataset from a collection of datasets.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>datasets</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code>[<a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a>]) – A collection of Datasets</p>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -519,7 +843,7 @@ <h1 id="api-graphnet-data-dataset-dataset--page-root">dataset<a class="headerlin
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataset.html b/api/graphnet.data.dataset.html
index 9c721d294..a4b095f8e 100644
--- a/api/graphnet.data.dataset.html
+++ b/api/graphnet.data.dataset.html
@@ -467,8 +467,9 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="dataset">
-<h1 id="api-graphnet-data-dataset--page-root">dataset<a class="headerlink" href="#api-graphnet-data-dataset--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.dataset">
+<span id="dataset"></span><h1 id="api-graphnet-data-dataset--page-root">dataset<a class="headerlink" href="#api-graphnet-data-dataset--page-root" title="Link to this heading">¶</a></h1>
+<p>Dataset classes for training in GraphNeT.</p>
 <p><h2> Subpackages </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
@@ -486,7 +487,14 @@ <h1 id="api-graphnet-data-dataset--page-root">dataset<a class="headerlink" href=
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.dataset.html">dataset</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.dataset.html">dataset</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.ColumnMissingException"><code class="docutils literal notranslate"><span class="pre">ColumnMissingException</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.load_module"><code class="docutils literal notranslate"><span class="pre">load_module()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.parse_graph_definition"><code class="docutils literal notranslate"><span class="pre">parse_graph_definition()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset"><code class="docutils literal notranslate"><span class="pre">Dataset</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.EnsembleDataset"><code class="docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -538,7 +546,7 @@ <h1 id="api-graphnet-data-dataset--page-root">dataset<a class="headerlink" href=
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataset.parquet.html b/api/graphnet.data.dataset.parquet.html
index 873efa2c0..492684fad 100644
--- a/api/graphnet.data.dataset.parquet.html
+++ b/api/graphnet.data.dataset.parquet.html
@@ -475,12 +475,16 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="parquet">
-<h1 id="api-graphnet-data-dataset-parquet--page-root">parquet<a class="headerlink" href="#api-graphnet-data-dataset-parquet--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.dataset.parquet">
+<span id="parquet"></span><h1 id="api-graphnet-data-dataset-parquet--page-root">parquet<a class="headerlink" href="#api-graphnet-data-dataset-parquet--page-root" title="Link to this heading">¶</a></h1>
+<p>Datasets using parquet backend.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.parquet.parquet_dataset.html">parquet_dataset</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.parquet.parquet_dataset.html">parquet_dataset</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset"><code class="docutils literal notranslate"><span class="pre">ParquetDataset</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -532,7 +536,7 @@ <h1 id="api-graphnet-data-dataset-parquet--page-root">parquet<a class="headerlin
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataset.parquet.parquet_dataset.html b/api/graphnet.data.dataset.parquet.parquet_dataset.html
index 67cfc6f08..3598486cf 100644
--- a/api/graphnet.data.dataset.parquet.parquet_dataset.html
+++ b/api/graphnet.data.dataset.parquet.parquet_dataset.html
@@ -328,11 +328,36 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataset-parquet-parquet-dataset--page-root" class="md-nav__link">parquet_dataset</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ParquetDataset</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ParquetDataset</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -466,7 +491,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataset-parquet-parquet-dataset--page-root" class="md-nav__link">parquet_dataset</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ParquetDataset</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -476,8 +512,82 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="parquet-dataset">
-<h1 id="api-graphnet-data-dataset-parquet-parquet-dataset--page-root">parquet_dataset<a class="headerlink" href="#api-graphnet-data-dataset-parquet-parquet-dataset--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.dataset.parquet.parquet_dataset">
+<span id="parquet-dataset"></span><h1 id="api-graphnet-data-dataset-parquet-parquet-dataset--page-root">parquet_dataset<a class="headerlink" href="#api-graphnet-data-dataset-parquet-parquet-dataset--page-root" title="Link to this heading">¶</a></h1>
+<p><cite>Dataset</cite> class(es) for reading from Parquet files.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.parquet.parquet_dataset.</span></span><span class="sig-name descname"><span class="pre">ParquetDataset</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/parquet/parquet_dataset.html#ParquetDataset"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a></p>
+<p>Pytorch dataset for reading from Parquet files.</p>
+<p>Construct Dataset.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Path to the file(s) from which this <cite>Dataset</cite> should read.</p></li>
+<li><p><strong>pulsemaps</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Name(s) of the pulse map series that should be used to
+construct the nodes on the individual graph objects, and their
+features. Multiple pulse series maps can be used, e.g., when
+different DOM types are stored in different maps.</p></li>
+<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of columns in the input files that should be used as
+node features on the graph objects.</p></li>
+<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of event-level columns in the input files that should
+be used added as attributes on the  graph objects.</p></li>
+<li><p><strong>node_truth</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of node-level columns in the input files that
+should be used added as attributes on the graph objects.</p></li>
+<li><p><strong>index_column</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'event_no'</span></code>) – Name of the column in the input files that contains
+unique indicies to identify and map events across tables.</p></li>
+<li><p><strong>truth_table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'truth'</span></code>) – Name of the table containing event-level truth
+information.</p></li>
+<li><p><strong>node_truth_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing node-level truth
+information.</p></li>
+<li><p><strong>string_selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of strings for which data should be read
+and used to construct graph objects. Defaults to None, meaning
+all strings for which data exists are used.</p></li>
+<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The events that should be read. This can be given either
+as list of indicies (in <cite>index_column</cite>); or a string-based
+selection used to query the <cite>Dataset</cite> for events passing the
+selection. Defaults to None, meaning that all events in the
+input files are read.</p></li>
+<li><p><strong>dtype</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>, default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – Type of the feature tensor on the graph objects returned.</p></li>
+<li><p><strong>loss_weight_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing per-event loss
+weights.</p></li>
+<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the column in <cite>loss_weight_table</cite>
+containing per-event loss weights. This is also the name of the
+corresponding attribute assigned to the graph object.</p></li>
+<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Default per-event loss weight.
+NOTE: This default value is only applied when
+<cite>loss_weight_table</cite> and <cite>loss_weight_column</cite> are specified, and
+in this case to events with no value in the corresponding
+table/column. That is, if no per-event loss weight table/column
+is provided, this value is ignored. Defaults to None.</p></li>
+<li><p><strong>seed</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Random number generator seed, used for selecting a random
+subset of events when resolving a string-based selection (e.g.,
+<cite>“10000 random events ~ event_no % 5 &gt; 0”</cite> or <cite>“20% random
+events ~ event_no % 5 &gt; 0”</cite>).</p></li>
+<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>) – Method that defines the graph representation.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table">
+<span class="sig-name descname"><span class="pre">query_table</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sequential_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/parquet/parquet_dataset.html#ParquetDataset.query_table"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table" title="Link to this definition">¶</a></dt>
+<dd><p>Query table at a specific index, optionally with some selection.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>table</strong> (<em>str</em>) – </p></li>
+<li><p><strong>columns</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>str</em>) – </p></li>
+<li><p><strong>sequential_index</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>selection</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -527,7 +637,7 @@ <h1 id="api-graphnet-data-dataset-parquet-parquet-dataset--page-root">parquet_da
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataset.sqlite.html b/api/graphnet.data.dataset.sqlite.html
index 9cfc9a2a5..8ea0f5783 100644
--- a/api/graphnet.data.dataset.sqlite.html
+++ b/api/graphnet.data.dataset.sqlite.html
@@ -482,13 +482,20 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="sqlite">
-<h1 id="api-graphnet-data-dataset-sqlite--page-root">sqlite<a class="headerlink" href="#api-graphnet-data-dataset-sqlite--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.dataset.sqlite">
+<span id="sqlite"></span><h1 id="api-graphnet-data-dataset-sqlite--page-root">sqlite<a class="headerlink" href="#api-graphnet-data-dataset-sqlite--page-root" title="Link to this heading">¶</a></h1>
+<p>Datasets using SQLite backend.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset.html">sqlite_dataset</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html">sqlite_dataset_perturbed</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset.html">sqlite_dataset</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset"><code class="docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html">sqlite_dataset_perturbed</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed"><code class="docutils literal notranslate"><span class="pre">SQLiteDatasetPerturbed</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -540,7 +547,7 @@ <h1 id="api-graphnet-data-dataset-sqlite--page-root">sqlite<a class="headerlink"
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataset.sqlite.sqlite_dataset.html b/api/graphnet.data.dataset.sqlite.sqlite_dataset.html
index 36b0988a6..7d25db896 100644
--- a/api/graphnet.data.dataset.sqlite.sqlite_dataset.html
+++ b/api/graphnet.data.dataset.sqlite.sqlite_dataset.html
@@ -335,11 +335,36 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root" class="md-nav__link">sqlite_dataset</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -473,7 +498,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root" class="md-nav__link">sqlite_dataset</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -483,8 +519,82 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="sqlite-dataset">
-<h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root">sqlite_dataset<a class="headerlink" href="#api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.dataset.sqlite.sqlite_dataset">
+<span id="sqlite-dataset"></span><h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root">sqlite_dataset<a class="headerlink" href="#api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root" title="Link to this heading">¶</a></h1>
+<p><cite>Dataset</cite> class(es) for reading data from SQLite databases.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.sqlite.sqlite_dataset.</span></span><span class="sig-name descname"><span class="pre">SQLiteDataset</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html#SQLiteDataset"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a></p>
+<p>Pytorch dataset for reading data from SQLite databases.</p>
+<p>Construct Dataset.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Path to the file(s) from which this <cite>Dataset</cite> should read.</p></li>
+<li><p><strong>pulsemaps</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Name(s) of the pulse map series that should be used to
+construct the nodes on the individual graph objects, and their
+features. Multiple pulse series maps can be used, e.g., when
+different DOM types are stored in different maps.</p></li>
+<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of columns in the input files that should be used as
+node features on the graph objects.</p></li>
+<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of event-level columns in the input files that should
+be used added as attributes on the  graph objects.</p></li>
+<li><p><strong>node_truth</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of node-level columns in the input files that
+should be used added as attributes on the graph objects.</p></li>
+<li><p><strong>index_column</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'event_no'</span></code>) – Name of the column in the input files that contains
+unique indicies to identify and map events across tables.</p></li>
+<li><p><strong>truth_table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'truth'</span></code>) – Name of the table containing event-level truth
+information.</p></li>
+<li><p><strong>node_truth_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing node-level truth
+information.</p></li>
+<li><p><strong>string_selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of strings for which data should be read
+and used to construct graph objects. Defaults to None, meaning
+all strings for which data exists are used.</p></li>
+<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The events that should be read. This can be given either
+as list of indicies (in <cite>index_column</cite>); or a string-based
+selection used to query the <cite>Dataset</cite> for events passing the
+selection. Defaults to None, meaning that all events in the
+input files are read.</p></li>
+<li><p><strong>dtype</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>, default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – Type of the feature tensor on the graph objects returned.</p></li>
+<li><p><strong>loss_weight_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing per-event loss
+weights.</p></li>
+<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the column in <cite>loss_weight_table</cite>
+containing per-event loss weights. This is also the name of the
+corresponding attribute assigned to the graph object.</p></li>
+<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Default per-event loss weight.
+NOTE: This default value is only applied when
+<cite>loss_weight_table</cite> and <cite>loss_weight_column</cite> are specified, and
+in this case to events with no value in the corresponding
+table/column. That is, if no per-event loss weight table/column
+is provided, this value is ignored. Defaults to None.</p></li>
+<li><p><strong>seed</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Random number generator seed, used for selecting a random
+subset of events when resolving a string-based selection (e.g.,
+<cite>“10000 random events ~ event_no % 5 &gt; 0”</cite> or <cite>“20% random
+events ~ event_no % 5 &gt; 0”</cite>).</p></li>
+<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>) – Method that defines the graph representation.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table">
+<span class="sig-name descname"><span class="pre">query_table</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sequential_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html#SQLiteDataset.query_table"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table" title="Link to this definition">¶</a></dt>
+<dd><p>Query table at a specific index, optionally with some selection.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>table</strong> (<em>str</em>) – </p></li>
+<li><p><strong>columns</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>str</em>) – </p></li>
+<li><p><strong>sequential_index</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>selection</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -534,7 +644,7 @@ <h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root">sqlite_datas
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html b/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html
index e5b66ee43..f46386964 100644
--- a/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html
+++ b/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html
@@ -342,11 +342,25 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root" class="md-nav__link">sqlite_dataset_perturbed</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDatasetPerturbed</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDatasetPerturbed</span></code></a>
       
     
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -473,7 +487,14 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root" class="md-nav__link">sqlite_dataset_perturbed</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDatasetPerturbed</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -483,8 +504,65 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="sqlite-dataset-perturbed">
-<h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root">sqlite_dataset_perturbed<a class="headerlink" href="#api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed">
+<span id="sqlite-dataset-perturbed"></span><h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root">sqlite_dataset_perturbed<a class="headerlink" href="#api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root" title="Link to this heading">¶</a></h1>
+<p><cite>Dataset</cite> class(es) for reading perturbed data from SQLite databases.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.</span></span><span class="sig-name descname"><span class="pre">SQLiteDatasetPerturbed</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">perturbation_dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html#SQLiteDatasetPerturbed"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" title="graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a></p>
+<p>Pytorch dataset for reading perturbed data from SQLite databases.</p>
+<p>This including a pre-processing step, where the input data is randomly
+perturbed according to given per-feature “noise” levels. This is intended
+to test the stability of a trained model under small changes to the input
+parameters.</p>
+<p>Construct SQLiteDatasetPerturbed.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Path to the file(s) from which this <cite>Dataset</cite> should read.</p></li>
+<li><p><strong>pulsemaps</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Name(s) of the pulse map series that should be used to
+construct the nodes on the individual graph objects, and their
+features. Multiple pulse series maps can be used, e.g., when
+different DOM types are stored in different maps.</p></li>
+<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of columns in the input files that should be used as
+node features on the graph objects.</p></li>
+<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of event-level columns in the input files that should
+be used added as attributes on the  graph objects.</p></li>
+<li><p><strong>perturbation_dict</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>float</em><em>]</em>) – Dictionary mapping a feature
+name to a standard deviation according to which the values for
+this feature should be randomly perturbed.</p></li>
+<li><p><strong>node_truth</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of node-level columns in the input files that
+should be used added as attributes on the graph objects.</p></li>
+<li><p><strong>index_column</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'event_no'</span></code>) – Name of the column in the input files that contains
+unique indicies to identify and map events across tables.</p></li>
+<li><p><strong>truth_table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'truth'</span></code>) – Name of the table containing event-level truth
+information.</p></li>
+<li><p><strong>node_truth_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing node-level truth
+information.</p></li>
+<li><p><strong>string_selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of strings for which data should be read
+and used to construct graph objects. Defaults to None, meaning
+all strings for which data exists are used.</p></li>
+<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of indicies (in <cite>index_column</cite>) of the events in
+the input files that should be read. Defaults to None, meaning
+that all events in the input files are read.</p></li>
+<li><p><strong>dtype</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>, default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – Type of the feature tensor on the graph objects returned.</p></li>
+<li><p><strong>loss_weight_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing per-event loss
+weights.</p></li>
+<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the column in <cite>loss_weight_table</cite>
+containing per-event loss weights. This is also the name of the
+corresponding attribute assigned to the graph object.</p></li>
+<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Default per-event loss weight.
+NOTE: This default value is only applied when
+<cite>loss_weight_table</cite> and <cite>loss_weight_column</cite> are specified, and
+in this case to events with no value in the corresponding
+table/column. That is, if no per-event loss weight table/column
+is provided, this value is ignored. Defaults to None.</p></li>
+<li><p><strong>seed</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Generator</span></code>, <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional seed for random number generation. Defaults to None.</p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -534,7 +612,7 @@ <h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root">sq
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.html b/api/graphnet.data.extractors.html
index df9ce380e..5810ea8fc 100644
--- a/api/graphnet.data.extractors.html
+++ b/api/graphnet.data.extractors.html
@@ -658,7 +658,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3extractor.html b/api/graphnet.data.extractors.i3extractor.html
index 884ec4690..4671c09f7 100644
--- a/api/graphnet.data.extractors.i3extractor.html
+++ b/api/graphnet.data.extractors.i3extractor.html
@@ -732,7 +732,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3featureextractor.html b/api/graphnet.data.extractors.i3featureextractor.html
index fd27d93af..520605dfe 100644
--- a/api/graphnet.data.extractors.i3featureextractor.html
+++ b/api/graphnet.data.extractors.i3featureextractor.html
@@ -720,7 +720,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3genericextractor.html b/api/graphnet.data.extractors.i3genericextractor.html
index fffb0ae7a..2915de7ce 100644
--- a/api/graphnet.data.extractors.i3genericextractor.html
+++ b/api/graphnet.data.extractors.i3genericextractor.html
@@ -639,7 +639,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3hybridrecoextractor.html b/api/graphnet.data.extractors.i3hybridrecoextractor.html
index c13a47727..e23c93bd9 100644
--- a/api/graphnet.data.extractors.i3hybridrecoextractor.html
+++ b/api/graphnet.data.extractors.i3hybridrecoextractor.html
@@ -623,7 +623,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html b/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html
index d87d2a70a..f9a853695 100644
--- a/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html
+++ b/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html
@@ -626,7 +626,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3particleextractor.html b/api/graphnet.data.extractors.i3particleextractor.html
index ed80239fc..fbdb86f6e 100644
--- a/api/graphnet.data.extractors.i3particleextractor.html
+++ b/api/graphnet.data.extractors.i3particleextractor.html
@@ -625,7 +625,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3pisaextractor.html b/api/graphnet.data.extractors.i3pisaextractor.html
index 82a7edbf0..4471bf09a 100644
--- a/api/graphnet.data.extractors.i3pisaextractor.html
+++ b/api/graphnet.data.extractors.i3pisaextractor.html
@@ -623,7 +623,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3quesoextractor.html b/api/graphnet.data.extractors.i3quesoextractor.html
index 155c83c34..649bb0171 100644
--- a/api/graphnet.data.extractors.i3quesoextractor.html
+++ b/api/graphnet.data.extractors.i3quesoextractor.html
@@ -626,7 +626,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3retroextractor.html b/api/graphnet.data.extractors.i3retroextractor.html
index 579e01cb4..86f076405 100644
--- a/api/graphnet.data.extractors.i3retroextractor.html
+++ b/api/graphnet.data.extractors.i3retroextractor.html
@@ -623,7 +623,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3splinempeextractor.html b/api/graphnet.data.extractors.i3splinempeextractor.html
index fd7c55164..f229b0090 100644
--- a/api/graphnet.data.extractors.i3splinempeextractor.html
+++ b/api/graphnet.data.extractors.i3splinempeextractor.html
@@ -623,7 +623,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3truthextractor.html b/api/graphnet.data.extractors.i3truthextractor.html
index 8e211b71a..213393e35 100644
--- a/api/graphnet.data.extractors.i3truthextractor.html
+++ b/api/graphnet.data.extractors.i3truthextractor.html
@@ -630,7 +630,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.i3tumextractor.html b/api/graphnet.data.extractors.i3tumextractor.html
index 4b4ad2d29..326699de0 100644
--- a/api/graphnet.data.extractors.i3tumextractor.html
+++ b/api/graphnet.data.extractors.i3tumextractor.html
@@ -623,7 +623,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.utilities.collections.html b/api/graphnet.data.extractors.utilities.collections.html
index d963c742e..0970c915a 100644
--- a/api/graphnet.data.extractors.utilities.collections.html
+++ b/api/graphnet.data.extractors.utilities.collections.html
@@ -708,7 +708,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.utilities.frames.html b/api/graphnet.data.extractors.utilities.frames.html
index f895882b1..2fcdda991 100644
--- a/api/graphnet.data.extractors.utilities.frames.html
+++ b/api/graphnet.data.extractors.utilities.frames.html
@@ -709,7 +709,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.utilities.html b/api/graphnet.data.extractors.utilities.html
index da9334938..d7361ced0 100644
--- a/api/graphnet.data.extractors.utilities.html
+++ b/api/graphnet.data.extractors.utilities.html
@@ -640,7 +640,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.extractors.utilities.types.html b/api/graphnet.data.extractors.utilities.types.html
index 6c8154f02..10266dd2d 100644
--- a/api/graphnet.data.extractors.utilities.types.html
+++ b/api/graphnet.data.extractors.utilities.types.html
@@ -867,7 +867,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.html b/api/graphnet.data.html
index dff3daeae..57b5f245a 100644
--- a/api/graphnet.data.html
+++ b/api/graphnet.data.html
@@ -507,8 +507,16 @@
 <li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter"><code class="docutils literal notranslate"><span class="pre">DataConverter</span></code></a></li>
 </ul>
 </li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataloader.html">dataloader</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.data.pipeline.html">pipeline</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataloader.html">dataloader</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataloader.html#graphnet.data.dataloader.collate_fn"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataloader.html#graphnet.data.dataloader.do_shuffle"><code class="docutils literal notranslate"><span class="pre">do_shuffle()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader"><code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.data.pipeline.html">pipeline</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.data.pipeline.html#graphnet.data.pipeline.InSQLitePipeline"><code class="docutils literal notranslate"><span class="pre">InSQLitePipeline</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -560,7 +568,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.parquet.html b/api/graphnet.data.parquet.html
index 31148efc3..ebb6a3f5f 100644
--- a/api/graphnet.data.parquet.html
+++ b/api/graphnet.data.parquet.html
@@ -514,7 +514,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.parquet.parquet_dataconverter.html b/api/graphnet.data.parquet.parquet_dataconverter.html
index 3b17b9603..a0f810016 100644
--- a/api/graphnet.data.parquet.parquet_dataconverter.html
+++ b/api/graphnet.data.parquet.parquet_dataconverter.html
@@ -665,7 +665,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.pipeline.html b/api/graphnet.data.pipeline.html
index 633c55e35..3e14a0e83 100644
--- a/api/graphnet.data.pipeline.html
+++ b/api/graphnet.data.pipeline.html
@@ -372,11 +372,25 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-pipeline--page-root" class="md-nav__link">pipeline</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.pipeline.InSQLitePipeline" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InSQLitePipeline</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.data.pipeline.InSQLitePipeline" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InSQLitePipeline</span></code></a>
       
     
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -436,7 +450,14 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-data-pipeline--page-root" class="md-nav__link">pipeline</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.data.pipeline.InSQLitePipeline" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InSQLitePipeline</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -446,8 +467,36 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="pipeline">
-<h1 id="api-graphnet-data-pipeline--page-root">pipeline<a class="headerlink" href="#api-graphnet-data-pipeline--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.data.pipeline">
+<span id="pipeline"></span><h1 id="api-graphnet-data-pipeline--page-root">pipeline<a class="headerlink" href="#api-graphnet-data-pipeline--page-root" title="Link to this heading">¶</a></h1>
+<p>Class(es) used for analysis in PISA.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.data.pipeline.InSQLitePipeline">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.pipeline.</span></span><span class="sig-name descname"><span class="pre">InSQLitePipeline</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">module_dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retro_table_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">outdir</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pipeline_name</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/pipeline.html#InSQLitePipeline"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.pipeline.InSQLitePipeline" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code>, <a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger" title="graphnet.utilities.logging.Logger"><code class="xref py py-class docutils literal notranslate"><span class="pre">Logger</span></code></a></p>
+<p>Create a SQLite database for PISA analysis.</p>
+<p>The database will contain truth and GNN predictions and, if available,
+RETRO reconstructions.</p>
+<p>Initialise the pipeline.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>module_dict</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>) – A dictionary with GNN modules from GraphNet. E.g.
+{‘energy’: gnn_module_for_energy_regression}</p></li>
+<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of input features for the GNN modules.</p></li>
+<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of truth for the GNN ModuleList.</p></li>
+<li><p><strong>device</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">device</span></code>) – The device used for computation.</p></li>
+<li><p><strong>retro_table_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'retro'</span></code>) – Name of the retro table for.</p></li>
+<li><p><strong>outdir</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – the directory in which the pipeline database will be
+stored.</p></li>
+<li><p><strong>batch_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">100</span></code>) – Batch size for inference.</p></li>
+<li><p><strong>n_workers</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">10</span></code>) – Number of workers used in dataloading.</p></li>
+<li><p><strong>pipeline_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'pipeline'</span></code>) – Name of the pipeline. If such a pipeline already
+exists, an error will be prompted to avoid overwriting.</p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -497,7 +546,7 @@ <h1 id="api-graphnet-data-pipeline--page-root">pipeline<a class="headerlink" hre
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.sqlite.html b/api/graphnet.data.sqlite.html
index 16343c670..cd932166c 100644
--- a/api/graphnet.data.sqlite.html
+++ b/api/graphnet.data.sqlite.html
@@ -534,7 +534,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.sqlite.sqlite_dataconverter.html b/api/graphnet.data.sqlite.sqlite_dataconverter.html
index 302c9e50d..dfbb879df 100644
--- a/api/graphnet.data.sqlite.sqlite_dataconverter.html
+++ b/api/graphnet.data.sqlite.sqlite_dataconverter.html
@@ -776,7 +776,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.sqlite.sqlite_utilities.html b/api/graphnet.data.sqlite.sqlite_utilities.html
index b11e73e50..514fbdacf 100644
--- a/api/graphnet.data.sqlite.sqlite_utilities.html
+++ b/api/graphnet.data.sqlite.sqlite_utilities.html
@@ -727,7 +727,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.utilities.html b/api/graphnet.data.utilities.html
index 35675598e..894827708 100644
--- a/api/graphnet.data.utilities.html
+++ b/api/graphnet.data.utilities.html
@@ -536,7 +536,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.utilities.parquet_to_sqlite.html b/api/graphnet.data.utilities.parquet_to_sqlite.html
index eca0271ac..86d85b9eb 100644
--- a/api/graphnet.data.utilities.parquet_to_sqlite.html
+++ b/api/graphnet.data.utilities.parquet_to_sqlite.html
@@ -591,7 +591,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.utilities.random.html b/api/graphnet.data.utilities.random.html
index e3b1572b6..ee45ec2a4 100644
--- a/api/graphnet.data.utilities.random.html
+++ b/api/graphnet.data.utilities.random.html
@@ -563,7 +563,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.data.utilities.string_selection_resolver.html b/api/graphnet.data.utilities.string_selection_resolver.html
index 2ee151d57..2be85aed2 100644
--- a/api/graphnet.data.utilities.string_selection_resolver.html
+++ b/api/graphnet.data.utilities.string_selection_resolver.html
@@ -552,7 +552,7 @@
 <dl class="field-list simple">
 <dt class="field-odd">Parameters<span class="colon">:</span></dt>
 <dd class="field-odd"><ul class="simple">
-<li><p><strong>dataset</strong> (<em>Dataset</em>) – </p></li>
+<li><p><strong>dataset</strong> (<a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><em>Dataset</em></a>) – </p></li>
 <li><p><strong>index_column</strong> (<em>str</em>) – </p></li>
 <li><p><strong>seed</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li>
 <li><p><strong>use_cache</strong> (<em>bool</em>) – </p></li>
@@ -626,7 +626,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.deployment.html b/api/graphnet.deployment.html
index 534adda23..a21590d71 100644
--- a/api/graphnet.deployment.html
+++ b/api/graphnet.deployment.html
@@ -453,7 +453,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.deployment.i3modules.deployer.html b/api/graphnet.deployment.i3modules.deployer.html
index 5a1c55118..c29615e41 100644
--- a/api/graphnet.deployment.i3modules.deployer.html
+++ b/api/graphnet.deployment.i3modules.deployer.html
@@ -456,7 +456,7 @@ <h1 id="api-graphnet-deployment-i3modules-deployer--page-root">deployer<a class=
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.deployment.i3modules.graphnet_module.html b/api/graphnet.deployment.i3modules.graphnet_module.html
index 0005fa3ad..ac93cd1ea 100644
--- a/api/graphnet.deployment.i3modules.graphnet_module.html
+++ b/api/graphnet.deployment.i3modules.graphnet_module.html
@@ -336,11 +336,43 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-deployment-i3modules-graphnet-module--page-root" class="md-nav__link">graphnet_module</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3PulseCleanerModule</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3PulseCleanerModule</span></code></a>
+      
+    
+    </li></ul>
+    
     </li></ul>
     
     </li></ul>
@@ -395,7 +427,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-deployment-i3modules-graphnet-module--page-root" class="md-nav__link">graphnet_module</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3PulseCleanerModule</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -405,8 +448,94 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="graphnet-module">
-<h1 id="api-graphnet-deployment-i3modules-graphnet-module--page-root">graphnet_module<a class="headerlink" href="#api-graphnet-deployment-i3modules-graphnet-module--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.deployment.i3modules.graphnet_module">
+<span id="graphnet-module"></span><h1 id="api-graphnet-deployment-i3modules-graphnet-module--page-root">graphnet_module<a class="headerlink" href="#api-graphnet-deployment-i3modules-graphnet-module--page-root" title="Link to this heading">¶</a></h1>
+<p>Class(es) for deploying GraphNeT models in icetray as I3Modules.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.deployment.i3modules.graphnet_module.</span></span><span class="sig-name descname"><span class="pre">GraphNeTI3Module</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemap</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemap_extractor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gcd_file</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/deployment/i3modules/graphnet_module.html#GraphNeTI3Module"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">object</span></code></p>
+<p>Base I3 Module for GraphNeT.</p>
+<p>Contains methods for extracting pulsemaps, producing graphs and writing to
+frames.</p>
+<p>I3Module Constructor.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>) – An instance of GraphDefinition.  E.g. KNNGraph.</p></li>
+<li><p><strong>pulsemap</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – the pulse map on which the module functions</p></li>
+<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – the features that is used from the pulse map.
+E.g. [dom_x, dom_y, dom_z, charge]</p></li>
+<li><p><strong>pulsemap_extractor</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>], <a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>]) – The I3FeatureExtractor used to extract the
+pulsemap from the I3Frames</p></li>
+<li><p><strong>gcd_file</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – Path to the associated gcd-file.</p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.deployment.i3modules.graphnet_module.I3InferenceModule">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.deployment.i3modules.graphnet_module.</span></span><span class="sig-name descname"><span class="pre">I3InferenceModule</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pulsemap</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemap_extractor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_config</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">state_dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gcd_file</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/deployment/i3modules/graphnet_module.html#I3InferenceModule"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" title="graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a></p>
+<p>General class for inference on i3 frames.</p>
+<p>General class for inference on I3Frames (physics).</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>pulsemap</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – the pulsmap that the model is expecting as input.</p></li>
+<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – the features of the pulsemap that the model is expecting.</p></li>
+<li><p><strong>pulsemap_extractor</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>], <a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>]) – The extractor used to extract the pulsemap.</p></li>
+<li><p><strong>model_config</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<a class="reference internal" href="graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig" title="graphnet.utilities.config.model_config.ModelConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">ModelConfig</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – The ModelConfig (or path to it) that summarizes the
+model used for inference.</p></li>
+<li><p><strong>state_dict</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – Path to state_dict containing the learned weights.</p></li>
+<li><p><strong>model_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – The name used for the model. Will help define the
+named entry in the I3Frame. E.g. “dynedge”.</p></li>
+<li><p><strong>gcd_file</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – path to associated gcd file.</p></li>
+<li><p><strong>prediction_columns</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – <p>column names for the predictions of the model.
+Will help define the named entry in the I3Frame.</p>
+<blockquote>
+<div><p>E.g. [‘energy_reco’]. Optional.</p>
+</div></blockquote>
+</p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.deployment.i3modules.graphnet_module.</span></span><span class="sig-name descname"><span class="pre">I3PulseCleanerModule</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pulsemap</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemap_extractor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_config</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">state_dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_name</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gcd_file</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">threshold</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">discard_empty_events</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/deployment/i3modules/graphnet_module.html#I3PulseCleanerModule"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" title="graphnet.deployment.i3modules.graphnet_module.I3InferenceModule"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a></p>
+<p>A specialized module for pulse cleaning.</p>
+<p>It is assumed that the model provided has been trained for this.</p>
+<p>General class for inference on I3Frames (physics).</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>pulsemap</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – the pulsmap that the model is expecting as input
+(the one that is being cleaned).</p></li>
+<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – the features of the pulsemap that the model is expecting.</p></li>
+<li><p><strong>pulsemap_extractor</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>], <a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>]) – The extractor used to extract the pulsemap.</p></li>
+<li><p><strong>model_config</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – The ModelConfig (or path to it) that summarizes the
+model used for inference.</p></li>
+<li><p><strong>state_dict</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – Path to state_dict containing the learned weights.</p></li>
+<li><p><strong>model_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – The name used for the model. Will help define the named
+entry in the I3Frame. E.g. “dynedge”.</p></li>
+<li><p><strong>gcd_file</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – path to associated gcd file.</p></li>
+<li><p><strong>threshold</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>, default: <code class="docutils literal notranslate"><span class="pre">0.7</span></code>) – the threshold for being considered a positive case.
+E.g., predictions &gt;= threshold will be considered
+to be signal, all else noise.</p></li>
+<li><p><strong>discard_empty_events</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – When true, this flag will eliminate events
+whose cleaned pulse series are empty. Can be used
+to speed up processing especially for noise
+simulation, since it will not do any writing or
+further calculations.</p></li>
+<li><p><strong>prediction_columns</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – column names for the predictions of the model.
+Will help define the named entry in the I3Frame.
+E.g. [‘energy_reco’]. Optional.</p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -456,7 +585,7 @@ <h1 id="api-graphnet-deployment-i3modules-graphnet-module--page-root">graphnet_m
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.deployment.i3modules.html b/api/graphnet.deployment.i3modules.html
index cb57f2df6..975ce8207 100644
--- a/api/graphnet.deployment.i3modules.html
+++ b/api/graphnet.deployment.i3modules.html
@@ -410,7 +410,12 @@ <h1 id="api-graphnet-deployment-i3modules--page-root">i3modules<a class="headerl
 <div class="toctree-wrapper compound">
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="graphnet.deployment.i3modules.deployer.html">deployer</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html">graphnet_module</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html">graphnet_module</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module"><code class="docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule"><code class="docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule"><code class="docutils literal notranslate"><span class="pre">I3PulseCleanerModule</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -462,7 +467,7 @@ <h1 id="api-graphnet-deployment-i3modules--page-root">i3modules<a class="headerl
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.html b/api/graphnet.html
index b04aa430c..5cee1a97c 100644
--- a/api/graphnet.html
+++ b/api/graphnet.html
@@ -522,7 +522,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.coarsening.html b/api/graphnet.models.coarsening.html
index 1dbd5b679..5341555ea 100644
--- a/api/graphnet.models.coarsening.html
+++ b/api/graphnet.models.coarsening.html
@@ -125,6 +125,7 @@
     <script src="../_static/documentation_options.js?v=5929fcd5"></script>
     <script src="../_static/doctools.js?v=888ff710"></script>
     <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
     <link rel="icon" href="../_static/favicon.svg"/>
     <link rel="author" title="About these documents" href="../about.html" />
     <link rel="index" title="Index" href="../genindex.html" />
@@ -365,11 +366,90 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-coarsening--page-root" class="md-nav__link">coarsening</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.unbatch_edge_index" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">unbatch_edge_index()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Coarsening</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening.reduce_options" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reduce_options</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.AttributeCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AttributeCoarsening</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.DOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.CustomDOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CustomDOMCoarsening</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.DOMAndTimeWindowCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMAndTimeWindowCoarsening</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.coarsening.unbatch_edge_index" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">unbatch_edge_index()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.coarsening.Coarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Coarsening</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.coarsening.Coarsening.reduce_options" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reduce_options</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.coarsening.Coarsening.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.coarsening.AttributeCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AttributeCoarsening</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.coarsening.DOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.coarsening.CustomDOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CustomDOMCoarsening</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.coarsening.DOMAndTimeWindowCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMAndTimeWindowCoarsening</span></code></a>
       
     
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -436,7 +516,30 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-coarsening--page-root" class="md-nav__link">coarsening</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.unbatch_edge_index" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">unbatch_edge_index()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Coarsening</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening.reduce_options" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reduce_options</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.AttributeCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AttributeCoarsening</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.DOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.CustomDOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CustomDOMCoarsening</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.coarsening.DOMAndTimeWindowCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMAndTimeWindowCoarsening</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -446,8 +549,125 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="coarsening">
-<h1 id="api-graphnet-models-coarsening--page-root">coarsening<a class="headerlink" href="#api-graphnet-models-coarsening--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.coarsening">
+<span id="coarsening"></span><h1 id="api-graphnet-models-coarsening--page-root">coarsening<a class="headerlink" href="#api-graphnet-models-coarsening--page-root" title="Link to this heading">¶</a></h1>
+<p>Class(es) for coarsening operations (i.e., clustering, or local pooling).</p>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.coarsening.unbatch_edge_index">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">unbatch_edge_index</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#unbatch_edge_index"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.unbatch_edge_index" title="Link to this definition">¶</a></dt>
+<dd><p>Splits the <code class="xref py py-obj docutils literal notranslate"><span class="pre">edge_index</span></code> according to a <code class="xref py py-obj docutils literal notranslate"><span class="pre">batch</span></code> vector.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>edge_index</strong> (<em>Tensor</em>) – The edge_index tensor. Must be ordered.</p></li>
+<li><p><strong>batch</strong> (<em>LongTensor</em>) – The batch vector
+<span class="math notranslate nohighlight">\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\)</span>, which assigns each
+node to a specific example. Must be ordered.</p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List[Tensor]</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.coarsening.Coarsening">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">Coarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">reduce</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transfer_attributes</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#Coarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.Coarsening" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>Base class for coarsening operations.</p>
+<p>Construct <cite>Coarsening</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>reduce</strong> (<em>str</em>) – </p></li>
+<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.coarsening.Coarsening.reduce_options">
+<span class="sig-name descname"><span class="pre">reduce_options</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{'avg':</span> <span class="pre">(&lt;function</span> <span class="pre">avg_pool&gt;,</span> <span class="pre">&lt;function</span> <span class="pre">avg_pool_x&gt;),</span> <span class="pre">'max':</span> <span class="pre">(&lt;function</span> <span class="pre">max_pool&gt;,</span> <span class="pre">&lt;function</span> <span class="pre">max_pool_x&gt;),</span> <span class="pre">'min':</span> <span class="pre">(&lt;function</span> <span class="pre">min_pool&gt;,</span> <span class="pre">&lt;function</span> <span class="pre">min_pool_x&gt;),</span> <span class="pre">'sum':</span> <span class="pre">(&lt;function</span> <span class="pre">sum_pool&gt;,</span> <span class="pre">&lt;function</span> <span class="pre">sum_pool_x&gt;)}</span></em><a class="headerlink" href="#graphnet.models.coarsening.Coarsening.reduce_options" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.coarsening.Coarsening.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#Coarsening.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.Coarsening.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Perform coarsening operation.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em><em> | </em><em>Batch</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.coarsening.AttributeCoarsening">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">AttributeCoarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reduce</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transfer_attributes</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#AttributeCoarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.AttributeCoarsening" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.coarsening.Coarsening" title="graphnet.models.coarsening.Coarsening"><code class="xref py py-class docutils literal notranslate"><span class="pre">Coarsening</span></code></a></p>
+<p>Coarsen pulses based on specified attributes.</p>
+<p>Construct <cite>SimpleCoarsening</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>attributes</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>reduce</strong> (<em>str</em>) – </p></li>
+<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.coarsening.DOMCoarsening">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">DOMCoarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">reduce</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transfer_attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#DOMCoarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.DOMCoarsening" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.coarsening.Coarsening" title="graphnet.models.coarsening.Coarsening"><code class="xref py py-class docutils literal notranslate"><span class="pre">Coarsening</span></code></a></p>
+<p>Coarsen pulses to DOM-level.</p>
+<p>Cluster pulses on the same DOM.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>reduce</strong> (<em>str</em>) – </p></li>
+<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>keys</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.coarsening.CustomDOMCoarsening">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">CustomDOMCoarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">reduce</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transfer_attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#CustomDOMCoarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.CustomDOMCoarsening" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.coarsening.DOMCoarsening" title="graphnet.models.coarsening.DOMCoarsening"><code class="xref py py-class docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a></p>
+<p>Coarsen pulses to DOM-level with additional attributes.</p>
+<p>Cluster pulses on the same DOM.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>reduce</strong> (<em>str</em>) – </p></li>
+<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>keys</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.coarsening.DOMAndTimeWindowCoarsening">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">DOMAndTimeWindowCoarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="pre">time_window,</span> <span class="pre">reduce,</span> <span class="pre">transfer_attributes,</span> <span class="pre">keys=['dom_x',</span> <span class="pre">'dom_y',</span> <span class="pre">'dom_z',</span> <span class="pre">'rde',</span> <span class="pre">'pmt_area'],</span> <span class="pre">time_key</span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#DOMAndTimeWindowCoarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.DOMAndTimeWindowCoarsening" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.coarsening.Coarsening" title="graphnet.models.coarsening.Coarsening"><code class="xref py py-class docutils literal notranslate"><span class="pre">Coarsening</span></code></a></p>
+<p>Coarsen pulses to DOM-level, with additional time-window clustering.</p>
+<p>Cluster pulses on the same DOM within <cite>time_window</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>time_window</strong> (<em>float</em>) – </p></li>
+<li><p><strong>reduce</strong> (<em>str</em>) – </p></li>
+<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>keys</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>time_key</strong> (<em>str</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -497,7 +717,7 @@ <h1 id="api-graphnet-models-coarsening--page-root">coarsening<a class="headerlin
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.components.html b/api/graphnet.models.components.html
index 831dd56fa..a46e005c6 100644
--- a/api/graphnet.models.components.html
+++ b/api/graphnet.models.components.html
@@ -460,13 +460,31 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="components">
-<h1 id="api-graphnet-models-components--page-root">components<a class="headerlink" href="#api-graphnet-models-components--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.components">
+<span id="components"></span><h1 id="api-graphnet-models-components--page-root">components<a class="headerlink" href="#api-graphnet-models-components--page-root" title="Link to this heading">¶</a></h1>
+<p>Components for constructing models.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.components.layers.html">layers</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.components.pool.html">pool</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.components.layers.html">layers</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv"><code class="docutils literal notranslate"><span class="pre">DynEdgeConv</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito"><code class="docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans"><code class="docutils literal notranslate"><span class="pre">DynTrans</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.components.pool.html">pool</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool"><code class="docutils literal notranslate"><span class="pre">min_pool()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool_x"><code class="docutils literal notranslate"><span class="pre">min_pool_x()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_and_distribute"><code class="docutils literal notranslate"><span class="pre">sum_pool_and_distribute()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.group_by"><code class="docutils literal notranslate"><span class="pre">group_by()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_dom"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_dom()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_pmt"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_pmt()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_x"><code class="docutils literal notranslate"><span class="pre">sum_pool_x()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool_x"><code class="docutils literal notranslate"><span class="pre">std_pool_x()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool"><code class="docutils literal notranslate"><span class="pre">sum_pool()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool"><code class="docutils literal notranslate"><span class="pre">std_pool()</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -518,7 +536,7 @@ <h1 id="api-graphnet-models-components--page-root">components<a class="headerlin
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.components.layers.html b/api/graphnet.models.components.layers.html
index c9568c1c1..06e348db9 100644
--- a/api/graphnet.models.components.layers.html
+++ b/api/graphnet.models.components.layers.html
@@ -336,11 +336,94 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-components-layers--page-root" class="md-nav__link">layers</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynEdgeConv" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeConv</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynEdgeConv.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.reset_parameters" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reset_parameters()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.message" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">message()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynTrans" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynTrans</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynTrans.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.layers.DynEdgeConv" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeConv</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.layers.DynEdgeConv.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.layers.EdgeConvTito" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.layers.EdgeConvTito.reset_parameters" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reset_parameters()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.layers.EdgeConvTito.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.layers.EdgeConvTito.message" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">message()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.layers.DynTrans" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynTrans</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.layers.DynTrans.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -451,7 +534,34 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-components-layers--page-root" class="md-nav__link">layers</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynEdgeConv" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeConv</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynEdgeConv.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.reset_parameters" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reset_parameters()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.message" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">message()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynTrans" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynTrans</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynTrans.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -461,8 +571,145 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="layers">
-<h1 id="api-graphnet-models-components-layers--page-root">layers<a class="headerlink" href="#api-graphnet-models-components-layers--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.components.layers">
+<span id="layers"></span><h1 id="api-graphnet-models-components-layers--page-root">layers<a class="headerlink" href="#api-graphnet-models-components-layers--page-root" title="Link to this heading">¶</a></h1>
+<p>Class(es) implementing layers to be used in <cite>graphnet</cite> models.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.components.layers.DynEdgeConv">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.components.layers.</span></span><span class="sig-name descname"><span class="pre">DynEdgeConv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggr</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_neighbors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features_subset</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#DynEdgeConv"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.DynEdgeConv" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeConv</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></p>
+<p>Dynamical edge convolution layer.</p>
+<p>Construct <cite>DynEdgeConv</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nn</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>) – The MLP/torch.Module to be used within the <cite>EdgeConv</cite>.</p></li>
+<li><p><strong>aggr</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'max'</span></code>) – Aggregation method to be used with <cite>EdgeConv</cite>.</p></li>
+<li><p><strong>nb_neighbors</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">8</span></code>) – Number of neighbours to be clustered after the
+<cite>EdgeConv</cite> operation.</p></li>
+<li><p><strong>features_subset</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">slice</span></code>, <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of features in <cite>Data.x</cite> that should be used
+when dynamically performing the new graph clustering after the
+<cite>EdgeConv</cite> operation. Defaults to all features.</p></li>
+<li><p><strong>**kwargs</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>) – Additional features to be passed to <cite>EdgeConv</cite>.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.components.layers.DynEdgeConv.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#DynEdgeConv.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.DynEdgeConv.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>x</strong> (<em>Tensor</em>) – </p></li>
+<li><p><strong>edge_index</strong> (<em>Tensor</em><em> | </em><em>SparseTensor</em>) – </p></li>
+<li><p><strong>batch</strong> (<em>Tensor</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.components.layers.EdgeConvTito">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.components.layers.</span></span><span class="sig-name descname"><span class="pre">EdgeConvTito</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggr</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#EdgeConvTito"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.EdgeConvTito" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">MessagePassing</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></p>
+<p>Implementation of EdgeConvTito layer used in TITO solution for.</p>
+<p>‘IceCube - Neutrinos in Deep’ kaggle competition.</p>
+<p>Construct <cite>EdgeConvTito</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nn</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>) – The MLP/torch.Module to be used within the <cite>EdgeConvTito</cite>.</p></li>
+<li><p><strong>aggr</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'max'</span></code>) – Aggregation method to be used with <cite>EdgeConvTito</cite>.</p></li>
+<li><p><strong>**kwargs</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>) – Additional features to be passed to <cite>EdgeConvTito</cite>.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.components.layers.EdgeConvTito.reset_parameters">
+<span class="sig-name descname"><span class="pre">reset_parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#EdgeConvTito.reset_parameters"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.EdgeConvTito.reset_parameters" title="Link to this definition">¶</a></dt>
+<dd><p>Reset all learnable parameters of the module.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.components.layers.EdgeConvTito.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#EdgeConvTito.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.EdgeConvTito.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>x</strong> (<em>Tensor</em><em> | </em><em>Tuple</em><em>[</em><em>Tensor</em><em>, </em><em>Tensor</em><em>]</em>) – </p></li>
+<li><p><strong>edge_index</strong> (<em>Tensor</em><em> | </em><em>SparseTensor</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.components.layers.EdgeConvTito.message">
+<span class="sig-name descname"><span class="pre">message</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x_i</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_j</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#EdgeConvTito.message"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.EdgeConvTito.message" title="Link to this definition">¶</a></dt>
+<dd><p>Edgeconvtito message passing.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>x_i</strong> (<em>Tensor</em>) – </p></li>
+<li><p><strong>x_j</strong> (<em>Tensor</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.components.layers.DynTrans">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.components.layers.</span></span><span class="sig-name descname"><span class="pre">DynTrans</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">layer_sizes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggr</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features_subset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_head</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#DynTrans"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.DynTrans" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.components.layers.EdgeConvTito" title="graphnet.models.components.layers.EdgeConvTito"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></p>
+<p>Implementation of dynTrans1 layer used in TITO solution for.</p>
+<p>‘IceCube - Neutrinos in Deep’ kaggle competition.</p>
+<p>Construct <cite>DynTrans</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nn</strong> – The MLP/torch.Module to be used within the <cite>DynTrans</cite>.</p></li>
+<li><p><strong>layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of layer sizes to be used in <cite>DynTrans</cite>.</p></li>
+<li><p><strong>aggr</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'max'</span></code>) – Aggregation method to be used with <cite>DynTrans</cite>.</p></li>
+<li><p><strong>features_subset</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">slice</span></code>, <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of features in <cite>Data.x</cite> that should be used
+when dynamically performing the new graph clustering after the
+<cite>EdgeConv</cite> operation. Defaults to all features.</p></li>
+<li><p><strong>n_head</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">8</span></code>) – Number of heads to be used in the multiheadattention models.</p></li>
+<li><p><strong>**kwargs</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>) – Additional features to be passed to <cite>DynTrans</cite>.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.components.layers.DynTrans.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#DynTrans.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.DynTrans.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>x</strong> (<em>Tensor</em>) – </p></li>
+<li><p><strong>edge_index</strong> (<em>Tensor</em><em> | </em><em>SparseTensor</em>) – </p></li>
+<li><p><strong>batch</strong> (<em>Tensor</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -512,7 +759,7 @@ <h1 id="api-graphnet-models-components-layers--page-root">layers<a class="header
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.components.pool.html b/api/graphnet.models.components.pool.html
index da3409715..16bcb18ac 100644
--- a/api/graphnet.models.components.pool.html
+++ b/api/graphnet.models.components.pool.html
@@ -125,6 +125,7 @@
     <script src="../_static/documentation_options.js?v=5929fcd5"></script>
     <script src="../_static/doctools.js?v=888ff710"></script>
     <script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
+    <script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
     <link rel="icon" href="../_static/favicon.svg"/>
     <link rel="author" title="About these documents" href="../about.html" />
     <link rel="index" title="Index" href="../genindex.html" />
@@ -343,11 +344,106 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-components-pool--page-root" class="md-nav__link">pool</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.min_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.min_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool_x()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool_and_distribute" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_and_distribute()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_by" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_by()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_pulses_to_dom" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_dom()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_pulses_to_pmt" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_pmt()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_x()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.std_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool_x()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.std_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.min_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.min_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool_x()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.sum_pool_and_distribute" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_and_distribute()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.group_by" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_by()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.group_pulses_to_dom" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_dom()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.group_pulses_to_pmt" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_pmt()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.sum_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_x()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.std_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool_x()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.sum_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.components.pool.std_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool()</span></code></a>
+      
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -451,7 +547,32 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-components-pool--page-root" class="md-nav__link">pool</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.min_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.min_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool_x()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool_and_distribute" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_and_distribute()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_by" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_by()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_pulses_to_dom" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_dom()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_pulses_to_pmt" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_pmt()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_x()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.std_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool_x()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.components.pool.std_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -461,8 +582,220 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="pool">
-<h1 id="api-graphnet-models-components-pool--page-root">pool<a class="headerlink" href="#api-graphnet-models-components-pool--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.components.pool">
+<span id="pool"></span><h1 id="api-graphnet-models-components-pool--page-root">pool<a class="headerlink" href="#api-graphnet-models-components-pool--page-root" title="Link to this heading">¶</a></h1>
+<p>Functions for performing pooling/clustering/coarsening.</p>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.min_pool">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">min_pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#min_pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.min_pool" title="Link to this definition">¶</a></dt>
+<dd><p>Perform min-pooling of <cite>Data</cite>.</p>
+<p>Like <cite>max_pool, just negating `data.x</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>cluster</strong> (<em>LongTensor</em>) – </p></li>
+<li><p><strong>data</strong> (<em>Data</em>) – </p></li>
+<li><p><strong>transform</strong> (<em>Any</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.min_pool_x">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">min_pool_x</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">size</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#min_pool_x"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.min_pool_x" title="Link to this definition">¶</a></dt>
+<dd><p>Perform min-pooling of <cite>Tensor</cite>.</p>
+<p>Like <cite>max_pool_x, just negating `x</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>cluster</strong> (<em>LongTensor</em>) – </p></li>
+<li><p><strong>x</strong> (<em>Tensor</em>) – </p></li>
+<li><p><strong>batch</strong> (<em>LongTensor</em>) – </p></li>
+<li><p><strong>size</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.sum_pool_and_distribute">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">sum_pool_and_distribute</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">cluster_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#sum_pool_and_distribute"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.sum_pool_and_distribute" title="Link to this definition">¶</a></dt>
+<dd><p>Sum-pool values and distribute result to the individual nodes.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>tensor</strong> (<em>Tensor</em>) – </p></li>
+<li><p><strong>cluster_index</strong> (<em>LongTensor</em>) – </p></li>
+<li><p><strong>batch</strong> (<em>LongTensor</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.group_by">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">group_by</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#group_by"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.group_by" title="Link to this definition">¶</a></dt>
+<dd><p>Group nodes in <cite>data</cite> that have identical values of <cite>keys</cite>.</p>
+<p>This grouping is done with in each event in case of batching. This allows
+for, e.g., assigning the same index to all pulses on the same PMT or DOM in
+the same event. This can be used for coarsening graphs, e.g., from pulse-
+level to DOM-level by aggregating feature across each group returned by this
+method.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>data</strong> (<em>Data</em><em> | </em><em>Batch</em>) – </p></li>
+<li><p><strong>keys</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<p class="rubric">Example</p>
+<dl class="simple">
+<dt>Given:</dt><dd><p>data.f1 = [1,1,2,2,2]
+data.f2 = [6,7,7,7,8]</p>
+</dd>
+<dt>Calls:</dt><dd><p>groupby(data, [‘f1’])       -&gt; [0, 0, 1, 1, 1]
+groupby(data, [‘f2’])       -&gt; [0, 1, 1, 1, 2]
+groupby(data, [‘f1’, ‘f2’]) -&gt; [0, 1, 2, 2, 3]</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.group_pulses_to_dom">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">group_pulses_to_dom</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#group_pulses_to_dom"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.group_pulses_to_dom" title="Link to this definition">¶</a></dt>
+<dd><p>Group pulses on the same DOM, using DOM and string number.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.group_pulses_to_pmt">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">group_pulses_to_pmt</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#group_pulses_to_pmt"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.group_pulses_to_pmt" title="Link to this definition">¶</a></dt>
+<dd><p>Group pulses on the same PMT, using PMT, DOM, and string number.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.sum_pool_x">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">sum_pool_x</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">size</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#sum_pool_x"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.sum_pool_x" title="Link to this definition">¶</a></dt>
+<dd><p>Sum-pool node features according to the clustering defined in <cite>cluster</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>cluster</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Cluster vector <span class="math notranslate nohighlight">\(\mathbf{c} \in \{ 0,
+\ldots, N - 1 \}^N\)</span>, which assigns each node to a specific cluster.</p></li>
+<li><p><strong>x</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – Node feature matrix
+<span class="math notranslate nohighlight">\(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\)</span>.</p></li>
+<li><p><strong>batch</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Batch vector <span class="math notranslate nohighlight">\(\mathbf{b} \in {\{ 0, \ldots,
+B-1\}}^N\)</span>, which assigns each node to a specific example.</p></li>
+<li><p><strong>size</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The maximum number of clusters in a single
+example. This property is useful to obtain a batch-wise dense
+representation, <em>e.g.</em> for applying FC layers, but should only be
+used if the size of the maximum number of clusters per example is
+known in advance.</p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.std_pool_x">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">std_pool_x</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">size</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#std_pool_x"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.std_pool_x" title="Link to this definition">¶</a></dt>
+<dd><p>Std-pool node features according to the clustering defined in <cite>cluster</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>cluster</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Cluster vector <span class="math notranslate nohighlight">\(\mathbf{c} \in \{ 0,
+\ldots, N - 1 \}^N\)</span>, which assigns each node to a specific cluster.</p></li>
+<li><p><strong>x</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – Node feature matrix
+<span class="math notranslate nohighlight">\(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\)</span>.</p></li>
+<li><p><strong>batch</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Batch vector <span class="math notranslate nohighlight">\(\mathbf{b} \in {\{ 0, \ldots,
+B-1\}}^N\)</span>, which assigns each node to a specific example.</p></li>
+<li><p><strong>size</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The maximum number of clusters in a single
+example. This property is useful to obtain a batch-wise dense
+representation, <em>e.g.</em> for applying FC layers, but should only be
+used if the size of the maximum number of clusters per example is
+known in advance.</p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.sum_pool">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">sum_pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#sum_pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.sum_pool" title="Link to this definition">¶</a></dt>
+<dd><p>Pool and coarsen graph according to the clustering defined in <cite>cluster</cite>.</p>
+<p>All nodes within the same cluster will be represented as one node.
+Final node features are defined by the <em>sum</em> of features of all nodes
+within the same cluster, node positions are averaged and edge indices are
+defined to be the union of the edge indices of all nodes within the same
+cluster.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>cluster</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Cluster vector <span class="math notranslate nohighlight">\(\mathbf{c} \in \{ 0,
+\ldots, N - 1 \}^N\)</span>, which assigns each node to a specific cluster.</p></li>
+<li><p><strong>data</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>) – Graph data object.</p></li>
+<li><p><strong>transform</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – A function/transform that takes in the
+coarsened and pooled <code class="xref py py-obj docutils literal notranslate"><span class="pre">torch_geometric.data.Data</span></code> object and
+returns a transformed version.</p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.components.pool.std_pool">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">std_pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#std_pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.std_pool" title="Link to this definition">¶</a></dt>
+<dd><p>Pool and coarsen graph according to the clustering defined in <cite>cluster</cite>.</p>
+<p>All nodes within the same cluster will be represented as one node.
+Final node features are defined by the <em>std</em> of features of all nodes
+within the same cluster, node positions are averaged and edge indices are
+defined to be the union of the edge indices of all nodes within the same
+cluster.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>cluster</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Cluster vector <span class="math notranslate nohighlight">\(\mathbf{c} \in \{ 0,
+\ldots, N - 1 \}^N\)</span>, which assigns each node to a specific cluster.</p></li>
+<li><p><strong>data</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>) – Graph data object.</p></li>
+<li><p><strong>transform</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – A function/transform that takes in the
+coarsened and pooled <code class="xref py py-obj docutils literal notranslate"><span class="pre">torch_geometric.data.Data</span></code> object and
+returns a transformed version.</p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -512,7 +845,7 @@ <h1 id="api-graphnet-models-components-pool--page-root">pool<a class="headerlink
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.detector.detector.html b/api/graphnet.models.detector.detector.html
index b8a0e72ee..773954d8d 100644
--- a/api/graphnet.models.detector.detector.html
+++ b/api/graphnet.models.detector.detector.html
@@ -343,11 +343,45 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-detector-detector--page-root" class="md-nav__link">detector</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Detector</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.detector.Detector" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Detector</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.detector.Detector.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.detector.Detector.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -458,7 +492,20 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-detector-detector--page-root" class="md-nav__link">detector</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Detector</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -468,8 +515,44 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="detector">
-<h1 id="api-graphnet-models-detector-detector--page-root">detector<a class="headerlink" href="#api-graphnet-models-detector-detector--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.detector.detector">
+<span id="detector"></span><h1 id="api-graphnet-models-detector-detector--page-root">detector<a class="headerlink" href="#api-graphnet-models-detector-detector--page-root" title="Link to this heading">¶</a></h1>
+<p>Base detector-specific <cite>Model</cite> class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.detector.detector.Detector">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.detector.</span></span><span class="sig-name descname"><span class="pre">Detector</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/detector.html#Detector"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.detector.Detector" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>Base class for all detector-specific read-ins in graphnet.</p>
+<p>Construct <cite>Detector</cite>.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.detector.detector.Detector.feature_map">
+<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/detector.html#Detector.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.detector.Detector.feature_map" title="Link to this definition">¶</a></dt>
+<dd><p>List of features used/assumed by inheriting <cite>Detector</cite> objects.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.detector.detector.Detector.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">node_features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/detector.html#Detector.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.detector.Detector.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Pre-process graph <cite>Data</cite> features and build graph adjacency.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>node_features</strong> (<em>tensor</em>) – </p></li>
+<li><p><strong>node_feature_names</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -519,7 +602,7 @@ <h1 id="api-graphnet-models-detector-detector--page-root">detector<a class="head
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.detector.html b/api/graphnet.models.detector.html
index 0c7904b36..d42b5269f 100644
--- a/api/graphnet.models.detector.html
+++ b/api/graphnet.models.detector.html
@@ -467,14 +467,27 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="detector">
-<h1 id="api-graphnet-models-detector--page-root">detector<a class="headerlink" href="#api-graphnet-models-detector--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.detector">
+<span id="detector"></span><h1 id="api-graphnet-models-detector--page-root">detector<a class="headerlink" href="#api-graphnet-models-detector--page-root" title="Link to this heading">¶</a></h1>
+<p>Detector-specific modules, for data ingestion and standardisation.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.detector.html">detector</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.icecube.html">icecube</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.prometheus.html">prometheus</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.detector.html">detector</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector"><code class="docutils literal notranslate"><span class="pre">Detector</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.icecube.html">icecube</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86"><code class="docutils literal notranslate"><span class="pre">IceCube86</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle"><code class="docutils literal notranslate"><span class="pre">IceCubeKaggle</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore"><code class="docutils literal notranslate"><span class="pre">IceCubeDeepCore</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade"><code class="docutils literal notranslate"><span class="pre">IceCubeUpgrade</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.prometheus.html">prometheus</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus"><code class="docutils literal notranslate"><span class="pre">Prometheus</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -526,7 +539,7 @@ <h1 id="api-graphnet-models-detector--page-root">detector<a class="headerlink" h
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.detector.icecube.html b/api/graphnet.models.detector.icecube.html
index 003c24f60..e37d2e2da 100644
--- a/api/graphnet.models.detector.icecube.html
+++ b/api/graphnet.models.detector.icecube.html
@@ -350,11 +350,96 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-detector-icecube--page-root" class="md-nav__link">icecube</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCube86" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCube86</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCube86.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeKaggle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeKaggle</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeKaggle.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeDeepCore" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeDeepCore</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeUpgrade" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeUpgrade</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.icecube.IceCube86" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCube86</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.icecube.IceCube86.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.icecube.IceCubeKaggle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeKaggle</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.icecube.IceCubeKaggle.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.icecube.IceCubeDeepCore" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeDeepCore</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.icecube.IceCubeUpgrade" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeUpgrade</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -458,7 +543,36 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-detector-icecube--page-root" class="md-nav__link">icecube</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCube86" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCube86</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCube86.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeKaggle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeKaggle</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeKaggle.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeDeepCore" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeDeepCore</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeUpgrade" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeUpgrade</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -468,8 +582,85 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="icecube">
-<h1 id="api-graphnet-models-detector-icecube--page-root">icecube<a class="headerlink" href="#api-graphnet-models-detector-icecube--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.detector.icecube">
+<span id="icecube"></span><h1 id="api-graphnet-models-detector-icecube--page-root">icecube<a class="headerlink" href="#api-graphnet-models-detector-icecube--page-root" title="Link to this heading">¶</a></h1>
+<p>IceCube-specific <cite>Detector</cite> class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCube86">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.icecube.</span></span><span class="sig-name descname"><span class="pre">IceCube86</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCube86"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCube86" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p>
+<p><cite>Detector</cite> class for IceCube-86.</p>
+<p>Construct <cite>Detector</cite>.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCube86.feature_map">
+<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCube86.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCube86.feature_map" title="Link to this definition">¶</a></dt>
+<dd><p>Map standardization functions to each dimension of input data.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeKaggle">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.icecube.</span></span><span class="sig-name descname"><span class="pre">IceCubeKaggle</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeKaggle"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeKaggle" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p>
+<p><cite>Detector</cite> class for Kaggle Competition.</p>
+<p>Construct <cite>Detector</cite>.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeKaggle.feature_map">
+<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeKaggle.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeKaggle.feature_map" title="Link to this definition">¶</a></dt>
+<dd><p>Map standardization functions to each dimension of input data.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeDeepCore">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.icecube.</span></span><span class="sig-name descname"><span class="pre">IceCubeDeepCore</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeDeepCore"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeDeepCore" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p>
+<p><cite>Detector</cite> class for IceCube-DeepCore.</p>
+<p>Construct <cite>Detector</cite>.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeDeepCore.feature_map">
+<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeDeepCore.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map" title="Link to this definition">¶</a></dt>
+<dd><p>Map standardization functions to each dimension of input data.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeUpgrade">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.icecube.</span></span><span class="sig-name descname"><span class="pre">IceCubeUpgrade</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeUpgrade"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeUpgrade" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p>
+<p><cite>Detector</cite> class for IceCube-Upgrade.</p>
+<p>Construct <cite>Detector</cite>.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeUpgrade.feature_map">
+<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeUpgrade.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map" title="Link to this definition">¶</a></dt>
+<dd><p>Map standardization functions to each dimension of input data.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -519,7 +710,7 @@ <h1 id="api-graphnet-models-detector-icecube--page-root">icecube<a class="header
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.detector.prometheus.html b/api/graphnet.models.detector.prometheus.html
index 9b27ebab9..0c266b36a 100644
--- a/api/graphnet.models.detector.prometheus.html
+++ b/api/graphnet.models.detector.prometheus.html
@@ -357,11 +357,36 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-detector-prometheus--page-root" class="md-nav__link">prometheus</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.prometheus.Prometheus" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Prometheus</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.prometheus.Prometheus.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.prometheus.Prometheus" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Prometheus</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.detector.prometheus.Prometheus.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -458,7 +483,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-detector-prometheus--page-root" class="md-nav__link">prometheus</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.prometheus.Prometheus" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Prometheus</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.detector.prometheus.Prometheus.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -468,8 +504,28 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="prometheus">
-<h1 id="api-graphnet-models-detector-prometheus--page-root">prometheus<a class="headerlink" href="#api-graphnet-models-detector-prometheus--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.detector.prometheus">
+<span id="prometheus"></span><h1 id="api-graphnet-models-detector-prometheus--page-root">prometheus<a class="headerlink" href="#api-graphnet-models-detector-prometheus--page-root" title="Link to this heading">¶</a></h1>
+<p>Prometheus-specific <cite>Detector</cite> class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.detector.prometheus.Prometheus">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.prometheus.</span></span><span class="sig-name descname"><span class="pre">Prometheus</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/prometheus.html#Prometheus"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.prometheus.Prometheus" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p>
+<p><cite>Detector</cite> class for Prometheus prototype.</p>
+<p>Construct <cite>Detector</cite>.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.detector.prometheus.Prometheus.feature_map">
+<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/prometheus.html#Prometheus.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.prometheus.Prometheus.feature_map" title="Link to this definition">¶</a></dt>
+<dd><p>Map standardization functions to each dimension.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -519,7 +575,7 @@ <h1 id="api-graphnet-models-detector-prometheus--page-root">prometheus<a class="
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.gnn.convnet.html b/api/graphnet.models.gnn.convnet.html
index 1d2d39cd2..52d6ba1ea 100644
--- a/api/graphnet.models.gnn.convnet.html
+++ b/api/graphnet.models.gnn.convnet.html
@@ -350,11 +350,36 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-convnet--page-root" class="md-nav__link">convnet</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.convnet.ConvNet" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ConvNet</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.convnet.ConvNet.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.convnet.ConvNet" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ConvNet</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.convnet.ConvNet.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -472,7 +497,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-convnet--page-root" class="md-nav__link">convnet</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.convnet.ConvNet" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ConvNet</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.convnet.ConvNet.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -482,8 +518,42 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="convnet">
-<h1 id="api-graphnet-models-gnn-convnet--page-root">convnet<a class="headerlink" href="#api-graphnet-models-gnn-convnet--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.gnn.convnet">
+<span id="convnet"></span><h1 id="api-graphnet-models-gnn-convnet--page-root">convnet<a class="headerlink" href="#api-graphnet-models-gnn-convnet--page-root" title="Link to this heading">¶</a></h1>
+<p>Implementation of the ConvNet GNN model architecture.</p>
+<p>Author: Martin Ha Minh</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.gnn.convnet.ConvNet">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.convnet.</span></span><span class="sig-name descname"><span class="pre">ConvNet</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_intermediate</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_ratio</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/convnet.html#ConvNet"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.convnet.ConvNet" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><code class="xref py py-class docutils literal notranslate"><span class="pre">GNN</span></code></a></p>
+<p>ConvNet (convolutional network) model.</p>
+<p>Construct <cite>ConvNet</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nb_inputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of input features, i.e. dimension of input
+layer.</p></li>
+<li><p><strong>nb_outputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of prediction labels, i.e. dimension of
+output layer.</p></li>
+<li><p><strong>nb_intermediate</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">128</span></code>) – Number of nodes in intermediate layer(s).</p></li>
+<li><p><strong>dropout_ratio</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>, default: <code class="docutils literal notranslate"><span class="pre">0.3</span></code>) – Fraction of nodes to drop.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.gnn.convnet.ConvNet.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/convnet.html#ConvNet.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.convnet.ConvNet.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Apply learnable forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -533,7 +603,7 @@ <h1 id="api-graphnet-models-gnn-convnet--page-root">convnet<a class="headerlink"
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.gnn.dynedge.html b/api/graphnet.models.gnn.dynedge.html
index 1ae3f0c44..908284418 100644
--- a/api/graphnet.models.gnn.dynedge.html
+++ b/api/graphnet.models.gnn.dynedge.html
@@ -357,11 +357,36 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge--page-root" class="md-nav__link">dynedge</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge.DynEdge" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdge</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge.DynEdge.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.dynedge.DynEdge" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdge</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.dynedge.DynEdge.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -472,7 +497,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge--page-root" class="md-nav__link">dynedge</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge.DynEdge" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdge</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge.DynEdge.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -482,8 +518,64 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="dynedge">
-<h1 id="api-graphnet-models-gnn-dynedge--page-root">dynedge<a class="headerlink" href="#api-graphnet-models-gnn-dynedge--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.gnn.dynedge">
+<span id="dynedge"></span><h1 id="api-graphnet-models-gnn-dynedge--page-root">dynedge<a class="headerlink" href="#api-graphnet-models-gnn-dynedge--page-root" title="Link to this heading">¶</a></h1>
+<p>Implementation of the DynEdge GNN model architecture.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge.DynEdge">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.dynedge.</span></span><span class="sig-name descname"><span class="pre">DynEdge</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_inputs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_neighbours</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features_subset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dynedge_layer_sizes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">post_processing_layer_sizes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">readout_layer_sizes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">global_pooling_schemes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">add_global_variables_after_pooling</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge.html#DynEdge"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge.DynEdge" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><code class="xref py py-class docutils literal notranslate"><span class="pre">GNN</span></code></a></p>
+<p>DynEdge (dynamical edge convolutional) model.</p>
+<p>Construct <cite>DynEdge</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nb_inputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of input features on each node.</p></li>
+<li><p><strong>nb_neighbours</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">8</span></code>) – Number of neighbours to used in the k-nearest
+neighbour clustering which is performed after each (dynamical)
+edge convolution.</p></li>
+<li><p><strong>features_subset</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">slice</span></code>, <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The subset of latent features on each node that
+are used as metric dimensions when performing the k-nearest
+neighbours clustering. Defaults to [0,1,2].</p></li>
+<li><p><strong>dynedge_layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The layer sizes, or latent feature dimenions,
+used in the <cite>DynEdgeConv</cite> layer. Each entry in
+<cite>dynedge_layer_sizes</cite> corresponds to a single <cite>DynEdgeConv</cite>
+layer; the integers in the corresponding tuple corresponds to
+the layer sizes in the multi-layer perceptron (MLP) that is
+applied within each <cite>DynEdgeConv</cite> layer. That is, a list of
+size-two tuples means that all <cite>DynEdgeConv</cite> layers contain a
+two-layer MLP.
+Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)].</p></li>
+<li><p><strong>post_processing_layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Hidden layer sizes in the MLP
+following the skip-concatenation of the outputs of each
+<cite>DynEdgeConv</cite> layer. Defaults to [336, 256].</p></li>
+<li><p><strong>readout_layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Hidden layer sizes in the MLP following the
+post-processing _and_ optional global pooling. As this is the
+last layer(s) in the model, the last layer in the read-out
+yields the output of the <cite>DynEdge</cite> model. Defaults to [128,].</p></li>
+<li><p><strong>global_pooling_schemes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The list global pooling schemes to use.
+Options are: “min”, “max”, “mean”, and “sum”.</p></li>
+<li><p><strong>add_global_variables_after_pooling</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – Whether to add global variables
+after global pooling. The alternative is to  added (distribute)
+them to the individual nodes before any convolutional
+operations.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge.DynEdge.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge.html#DynEdge.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge.DynEdge.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Apply learnable forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -533,7 +625,7 @@ <h1 id="api-graphnet-models-gnn-dynedge--page-root">dynedge<a class="headerlink"
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.gnn.dynedge_jinst.html b/api/graphnet.models.gnn.dynedge_jinst.html
index f46df1144..407d0cb08 100644
--- a/api/graphnet.models.gnn.dynedge_jinst.html
+++ b/api/graphnet.models.gnn.dynedge_jinst.html
@@ -364,11 +364,36 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge-jinst--page-root" class="md-nav__link">dynedge_jinst</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeJINST</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeJINST</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -472,7 +497,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge-jinst--page-root" class="md-nav__link">dynedge_jinst</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeJINST</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -482,8 +518,39 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="dynedge-jinst">
-<h1 id="api-graphnet-models-gnn-dynedge-jinst--page-root">dynedge_jinst<a class="headerlink" href="#api-graphnet-models-gnn-dynedge-jinst--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.gnn.dynedge_jinst">
+<span id="dynedge-jinst"></span><h1 id="api-graphnet-models-gnn-dynedge-jinst--page-root">dynedge_jinst<a class="headerlink" href="#api-graphnet-models-gnn-dynedge-jinst--page-root" title="Link to this heading">¶</a></h1>
+<p>Implementation of the exact DynEdge architecture used in [2209.03042].</p>
+<p>Author: Rasmus Oersoe</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge_jinst.DynEdgeJINST">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.dynedge_jinst.</span></span><span class="sig-name descname"><span class="pre">DynEdgeJINST</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">layer_size_scale</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge_jinst.html#DynEdgeJINST"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><code class="xref py py-class docutils literal notranslate"><span class="pre">GNN</span></code></a></p>
+<p>DynEdge (dynamical edge convolutional) model used in [2209.03042].</p>
+<p>Construct <cite>DynEdgeJINST</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nb_inputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of input features.</p></li>
+<li><p><strong>nb_outputs</strong> – Number of output features.</p></li>
+<li><p><strong>layer_size_scale</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">4</span></code>) – Integer that scales the size of hidden layers.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge_jinst.html#DynEdgeJINST.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Apply learnable forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -533,7 +600,7 @@ <h1 id="api-graphnet-models-gnn-dynedge-jinst--page-root">dynedge_jinst<a class=
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.gnn.dynedge_kaggle_tito.html b/api/graphnet.models.gnn.dynedge_kaggle_tito.html
index 6f259d9bc..a27928192 100644
--- a/api/graphnet.models.gnn.dynedge_kaggle_tito.html
+++ b/api/graphnet.models.gnn.dynedge_kaggle_tito.html
@@ -371,11 +371,36 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge-kaggle-tito--page-root" class="md-nav__link">dynedge_kaggle_tito</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeTITO</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeTITO</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -472,7 +497,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge-kaggle-tito--page-root" class="md-nav__link">dynedge_kaggle_tito</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeTITO</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -482,8 +518,49 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="dynedge-kaggle-tito">
-<h1 id="api-graphnet-models-gnn-dynedge-kaggle-tito--page-root">dynedge_kaggle_tito<a class="headerlink" href="#api-graphnet-models-gnn-dynedge-kaggle-tito--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.gnn.dynedge_kaggle_tito">
+<span id="dynedge-kaggle-tito"></span><h1 id="api-graphnet-models-gnn-dynedge-kaggle-tito--page-root">dynedge_kaggle_tito<a class="headerlink" href="#api-graphnet-models-gnn-dynedge-kaggle-tito--page-root" title="Link to this heading">¶</a></h1>
+<p>Implementation of DynEdge architecture used in.</p>
+<blockquote>
+<div><p>IceCube - Neutrinos in Deep Ice</p>
+</div></blockquote>
+<p>Reconstruct the direction of neutrinos from the Universe to the South Pole</p>
+<p>Kaggle competition.</p>
+<p>Solution by TITO.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.dynedge_kaggle_tito.</span></span><span class="sig-name descname"><span class="pre">DynEdgeTITO</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="pre">nb_inputs,</span> <span class="pre">features_subset(0,</span> <span class="pre">4,</span> <span class="pre">None),</span> <span class="pre">dyntrans_layer_sizes,</span> <span class="pre">global_pooling_schemes=['max']</span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge_kaggle_tito.html#DynEdgeTITO"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><code class="xref py py-class docutils literal notranslate"><span class="pre">GNN</span></code></a></p>
+<p>DynEdge (dynamical edge convolutional) model.</p>
+<p>Construct <cite>DynEdge</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nb_inputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of input features on each node.</p></li>
+<li><p><strong>features_subset</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">slice</span></code>, default: <code class="docutils literal notranslate"><span class="pre">slice(0,</span> <span class="pre">4,</span> <span class="pre">None)</span></code>) – The subset of latent features on each node that
+are used as metric dimensions when performing the k-nearest
+neighbours clustering. Defaults to [0,1,2,3].</p></li>
+<li><p><strong>dyntrans_layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The layer sizes, or latent feature dimenions,
+used in the <cite>DynTrans</cite> layer.</p></li>
+<li><p><strong>global_pooling_schemes</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">['max']</span></code>) – The list global pooling schemes to use.
+Options are: “min”, “max”, “mean”, and “sum”.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge_kaggle_tito.html#DynEdgeTITO.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Apply learnable forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -533,7 +610,7 @@ <h1 id="api-graphnet-models-gnn-dynedge-kaggle-tito--page-root">dynedge_kaggle_t
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.gnn.gnn.html b/api/graphnet.models.gnn.gnn.html
index a59853501..78ee2338d 100644
--- a/api/graphnet.models.gnn.gnn.html
+++ b/api/graphnet.models.gnn.gnn.html
@@ -378,11 +378,54 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-gnn--page-root" class="md-nav__link">gnn</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GNN</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.gnn.GNN" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GNN</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.gnn.GNN.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.gnn.GNN.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.gnn.gnn.GNN.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -472,7 +515,22 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-gnn-gnn--page-root" class="md-nav__link">gnn</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GNN</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -482,8 +540,47 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="gnn">
-<h1 id="api-graphnet-models-gnn-gnn--page-root">gnn<a class="headerlink" href="#api-graphnet-models-gnn-gnn--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.gnn.gnn">
+<span id="gnn"></span><h1 id="api-graphnet-models-gnn-gnn--page-root">gnn<a class="headerlink" href="#api-graphnet-models-gnn-gnn--page-root" title="Link to this heading">¶</a></h1>
+<p>Base GNN-specific <cite>Model</cite> class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.gnn.gnn.GNN">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.gnn.</span></span><span class="sig-name descname"><span class="pre">GNN</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_outputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/gnn.html#GNN"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.gnn.GNN" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>Base class for all core GNN models in graphnet.</p>
+<p>Construct <cite>GNN</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nb_inputs</strong> (<em>int</em>) – </p></li>
+<li><p><strong>nb_outputs</strong> (<em>int</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.gnn.gnn.GNN.nb_inputs">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.gnn.gnn.GNN.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd><p>Return number of input features.</p>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.gnn.gnn.GNN.nb_outputs">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_outputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.gnn.gnn.GNN.nb_outputs" title="Link to this definition">¶</a></dt>
+<dd><p>Return number of output features.</p>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.gnn.gnn.GNN.forward">
+<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/gnn.html#GNN.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.gnn.GNN.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Apply learnable forward pass in model.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -533,7 +630,7 @@ <h1 id="api-graphnet-models-gnn-gnn--page-root">gnn<a class="headerlink" href="#
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.gnn.html b/api/graphnet.models.gnn.html
index a98fe7164..be8f47a7f 100644
--- a/api/graphnet.models.gnn.html
+++ b/api/graphnet.models.gnn.html
@@ -481,16 +481,32 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="gnn">
-<h1 id="api-graphnet-models-gnn--page-root">gnn<a class="headerlink" href="#api-graphnet-models-gnn--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.gnn">
+<span id="gnn"></span><h1 id="api-graphnet-models-gnn--page-root">gnn<a class="headerlink" href="#api-graphnet-models-gnn--page-root" title="Link to this heading">¶</a></h1>
+<p>GNN-specific modules, for performing the main learnable operations.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.convnet.html">convnet</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge.html">dynedge</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge_jinst.html">dynedge_jinst</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge_kaggle_tito.html">dynedge_kaggle_tito</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.gnn.html">gnn</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.convnet.html">convnet</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet"><code class="docutils literal notranslate"><span class="pre">ConvNet</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge.html">dynedge</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge"><code class="docutils literal notranslate"><span class="pre">DynEdge</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge_jinst.html">dynedge_jinst</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST"><code class="docutils literal notranslate"><span class="pre">DynEdgeJINST</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge_kaggle_tito.html">dynedge_kaggle_tito</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO"><code class="docutils literal notranslate"><span class="pre">DynEdgeTITO</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.gnn.html">gnn</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN"><code class="docutils literal notranslate"><span class="pre">GNN</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -542,7 +558,7 @@ <h1 id="api-graphnet-models-gnn--page-root">gnn<a class="headerlink" href="#api-
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.graphs.edges.edges.html b/api/graphnet.models.graphs.edges.edges.html
index ea7739e62..98b04f64a 100644
--- a/api/graphnet.models.graphs.edges.edges.html
+++ b/api/graphnet.models.graphs.edges.edges.html
@@ -363,11 +363,63 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-graphs-edges-edges--page-root" class="md-nav__link">edges</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EdgeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EdgeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.KNNEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNEdges</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.RadialEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RadialEdges</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EuclideanEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanEdges</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.edges.edges.EdgeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.edges.edges.EdgeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.edges.edges.KNNEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNEdges</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.edges.edges.RadialEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RadialEdges</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.edges.edges.EuclideanEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanEdges</span></code></a>
+      
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -473,7 +525,24 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-graphs-edges-edges--page-root" class="md-nav__link">edges</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EdgeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EdgeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.KNNEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNEdges</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.RadialEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RadialEdges</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EuclideanEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanEdges</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -483,8 +552,102 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="edges">
-<h1 id="api-graphnet-models-graphs-edges-edges--page-root">edges<a class="headerlink" href="#api-graphnet-models-graphs-edges-edges--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.graphs.edges.edges">
+<span id="edges"></span><h1 id="api-graphnet-models-graphs-edges-edges--page-root">edges<a class="headerlink" href="#api-graphnet-models-graphs-edges-edges--page-root" title="Link to this heading">¶</a></h1>
+<p>Class(es) for building/connecting graphs.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.EdgeDefinition">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.edges.edges.</span></span><span class="sig-name descname"><span class="pre">EdgeDefinition</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">class_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">level</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">log_folder</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#EdgeDefinition"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.EdgeDefinition" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>Base class for graph building.</p>
+<p>Construct <cite>Logger</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>name</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>class_name</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>level</strong> (<em>int</em>) – </p></li>
+<li><p><strong>log_folder</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.EdgeDefinition.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">graph</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#EdgeDefinition.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.EdgeDefinition.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Construct edges based on problem specific implementation of.</p>
+<p>´_construct_edges´</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>graph</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>) – a graph without edges</p>
+</dd>
+<dt class="field-even">Returns<span class="colon">:</span></dt>
+<dd class="field-even"><p>a graph with edges</p>
+</dd>
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p>graph</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.KNNEdges">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.edges.edges.</span></span><span class="sig-name descname"><span class="pre">KNNEdges</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_nearest_neighbours</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#KNNEdges"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.KNNEdges" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.graphs.edges.edges.EdgeDefinition" title="graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a></p>
+<p>Builds edges from the k-nearest neighbours.</p>
+<p>K-NN Edge definition.</p>
+<p>Will connect nodes together with their ´nb_nearest_neighbours´
+nearest neighbours in the feature space given by ´columns´.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nb_nearest_neighbours</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – number of neighbours.</p></li>
+<li><p><strong>columns</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></code>) – Node features to use for distance calculation.</p></li>
+<li><p><strong>[</strong><strong>0</strong> (<em>Defaults to</em>) – </p></li>
+<li><p><strong>1</strong> – </p></li>
+<li><p><strong>2</strong><strong>]</strong><strong>.</strong> – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.RadialEdges">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.edges.edges.</span></span><span class="sig-name descname"><span class="pre">RadialEdges</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">radius</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#RadialEdges"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.RadialEdges" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.graphs.edges.edges.EdgeDefinition" title="graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a></p>
+<p>Builds graph from a sphere of chosen radius centred at each node.</p>
+<p>Radial edges.</p>
+<p>Connects each node to other nodes that are within a sphere of
+radius ´r´ centered at the node. The feature space of ´r´ is defined
+by ´columns´</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>radius</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>) – radius of sphere</p></li>
+<li><p><strong>columns</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></code>) – columns of the node feature matrix used.</p></li>
+<li><p><strong>[</strong><strong>0</strong> (<em>Defaults to</em>) – </p></li>
+<li><p><strong>1</strong> – </p></li>
+<li><p><strong>2</strong><strong>]</strong><strong>.</strong> – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.EuclideanEdges">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.edges.edges.</span></span><span class="sig-name descname"><span class="pre">EuclideanEdges</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">sigma</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">threshold</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#EuclideanEdges"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.EuclideanEdges" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.graphs.edges.edges.EdgeDefinition" title="graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a></p>
+<p>Builds edges according to Euclidean distance between nodes.</p>
+<p>See <a class="reference external" href="https://arxiv.org/pdf/1809.06166.pdf">https://arxiv.org/pdf/1809.06166.pdf</a>.</p>
+<p>Construct <cite>EuclideanEdges</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>sigma</strong> (<em>float</em>) – </p></li>
+<li><p><strong>threshold</strong> (<em>float</em>) – </p></li>
+<li><p><strong>columns</strong> (<em>List</em><em>[</em><em>int</em><em>]</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -534,7 +697,7 @@ <h1 id="api-graphnet-models-graphs-edges-edges--page-root">edges<a class="header
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.graphs.edges.html b/api/graphnet.models.graphs.edges.html
index e6d65d3e9..e4feb5f18 100644
--- a/api/graphnet.models.graphs.edges.html
+++ b/api/graphnet.models.graphs.edges.html
@@ -482,12 +482,22 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="edges">
-<h1 id="api-graphnet-models-graphs-edges--page-root">edges<a class="headerlink" href="#api-graphnet-models-graphs-edges--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.graphs.edges">
+<span id="edges"></span><h1 id="api-graphnet-models-graphs-edges--page-root">edges<a class="headerlink" href="#api-graphnet-models-graphs-edges--page-root" title="Link to this heading">¶</a></h1>
+<p>Modules for constructing graphs.</p>
+<p>´GraphDefinition´ defines the nodes and their features,  and contains general
+graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html">edges</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html">edges</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.KNNEdges"><code class="docutils literal notranslate"><span class="pre">KNNEdges</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.RadialEdges"><code class="docutils literal notranslate"><span class="pre">RadialEdges</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EuclideanEdges"><code class="docutils literal notranslate"><span class="pre">EuclideanEdges</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -539,7 +549,7 @@ <h1 id="api-graphnet-models-graphs-edges--page-root">edges<a class="headerlink"
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.graphs.graph_definition.html b/api/graphnet.models.graphs.graph_definition.html
index 45116549f..8dcade719 100644
--- a/api/graphnet.models.graphs.graph_definition.html
+++ b/api/graphnet.models.graphs.graph_definition.html
@@ -371,11 +371,36 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-graphs-graph-definition--page-root" class="md-nav__link">graph_definition</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.graph_definition.GraphDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.graph_definition.GraphDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.graph_definition.GraphDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.graph_definition.GraphDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -465,7 +490,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-graphs-graph-definition--page-root" class="md-nav__link">graph_definition</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.graph_definition.GraphDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.graph_definition.GraphDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -475,8 +511,59 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="graph-definition">
-<h1 id="api-graphnet-models-graphs-graph-definition--page-root">graph_definition<a class="headerlink" href="#api-graphnet-models-graphs-graph-definition--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.graphs.graph_definition">
+<span id="graph-definition"></span><h1 id="api-graphnet-models-graphs-graph-definition--page-root">graph_definition<a class="headerlink" href="#api-graphnet-models-graphs-graph-definition--page-root" title="Link to this heading">¶</a></h1>
+<p>Modules for defining graphs.</p>
+<p>These are self-contained graph definitions that hold all the graph-altering
+code in graphnet. These modules define what the GNNs sees as input and can be
+passed to dataloaders during training and deployment.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.graphs.graph_definition.GraphDefinition">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.graph_definition.</span></span><span class="sig-name descname"><span class="pre">GraphDefinition</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">detector</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/graph_definition.html#GraphDefinition"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.graph_definition.GraphDefinition" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>An Abstract class to create graph definitions from.</p>
+<p>Construct ´GraphDefinition´. The ´detector´ holds.</p>
+<p>´Detector´-specific code. E.g. scaling/standardization and geometry
+tables.</p>
+<p>´node_definition´ defines the nodes in the graph.</p>
+<p>´edge_definition´ defines the connectivity of the nodes in the graph.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>detector</strong> (<a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a>) – The corresponding ´Detector´ representing the data.</p></li>
+<li><p><strong>node_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition" title="graphnet.models.graphs.nodes.nodes.NodeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a>) – Definition of nodes.</p></li>
+<li><p><strong>edge_definition</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition" title="graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Definition of edges. Defaults to None.</p></li>
+<li><p><strong>node_feature_names</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Names of node feature columns. Defaults to None</p></li>
+<li><p><strong>dtype</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>], default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – data type used for node features. e.g. ´torch.float´</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.graphs.graph_definition.GraphDefinition.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">node_features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_dicts</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">custom_label_functions</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data_path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/graph_definition.html#GraphDefinition.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.graph_definition.GraphDefinition.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Construct graph as ´Data´ object.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>node_features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code>) – node features for graph. Shape ´[num_nodes, d]´</p></li>
+<li><p><strong>node_feature_names</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – name of each column. Shape ´[,d]´.</p></li>
+<li><p><strong>truth_dicts</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Dictionary containing truth labels.</p></li>
+<li><p><strong>custom_label_functions</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Custom label functions. See <a class="reference external" href="https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels">https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels</a>.</p></li>
+<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of column that holds loss weight. Defaults to None.</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Loss weight associated with event. Defaults to None.</p></li>
+<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.</p></li>
+<li><p><strong>data_path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Path to dataset data files. Defaults to None.</p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p>
+</dd>
+<dt class="field-odd">Returns<span class="colon">:</span></dt>
+<dd class="field-odd"><p>graph</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -526,7 +613,7 @@ <h1 id="api-graphnet-models-graphs-graph-definition--page-root">graph_definition
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.graphs.graphs.html b/api/graphnet.models.graphs.graphs.html
index ca229295b..0d889a033 100644
--- a/api/graphnet.models.graphs.graphs.html
+++ b/api/graphnet.models.graphs.graphs.html
@@ -378,11 +378,25 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-graphs-graphs--page-root" class="md-nav__link">graphs</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.graphs.KNNGraph" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNGraph</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.graphs.KNNGraph" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNGraph</span></code></a>
       
     
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -465,7 +479,14 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-graphs-graphs--page-root" class="md-nav__link">graphs</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.graphs.KNNGraph" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNGraph</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -475,8 +496,31 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="graphs">
-<h1 id="api-graphnet-models-graphs-graphs--page-root">graphs<a class="headerlink" href="#api-graphnet-models-graphs-graphs--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.graphs.graphs">
+<span id="graphs"></span><h1 id="api-graphnet-models-graphs-graphs--page-root">graphs<a class="headerlink" href="#api-graphnet-models-graphs-graphs--page-root" title="Link to this heading">¶</a></h1>
+<p>A module containing different graph representations in GraphNeT.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.graphs.graphs.KNNGraph">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.graphs.</span></span><span class="sig-name descname"><span class="pre">KNNGraph</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">detector</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_nearest_neighbours</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/graphs.html#KNNGraph"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.graphs.KNNGraph" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a></p>
+<p>A Graph representation where Edges are drawn to nearest neighbours.</p>
+<p>Construct k-nn graph representation.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>detector</strong> (<a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a>) – Detector that represents your data.</p></li>
+<li><p><strong>node_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition" title="graphnet.models.graphs.nodes.nodes.NodeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a>) – Definition of nodes in the graph.</p></li>
+<li><p><strong>node_feature_names</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of node features.</p></li>
+<li><p><strong>dtype</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>], default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – data type for node features.</p></li>
+<li><p><strong>nb_nearest_neighbours</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">8</span></code>) – Number of edges for each node. Defaults to 8.</p></li>
+<li><p><strong>columns</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></code>) – node feature columns used for distance calculation</p></li>
+<li><p><strong>[</strong><strong>0</strong> (<em>. Defaults to</em>) – </p></li>
+<li><p><strong>1</strong> – </p></li>
+<li><p><strong>2</strong><strong>]</strong><strong>.</strong> – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -526,7 +570,7 @@ <h1 id="api-graphnet-models-graphs-graphs--page-root">graphs<a class="headerlink
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.graphs.html b/api/graphnet.models.graphs.html
index fef0be82e..fbbc2ff3e 100644
--- a/api/graphnet.models.graphs.html
+++ b/api/graphnet.models.graphs.html
@@ -474,8 +474,12 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="graphs">
-<h1 id="api-graphnet-models-graphs--page-root">graphs<a class="headerlink" href="#api-graphnet-models-graphs--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.graphs">
+<span id="graphs"></span><h1 id="api-graphnet-models-graphs--page-root">graphs<a class="headerlink" href="#api-graphnet-models-graphs--page-root" title="Link to this heading">¶</a></h1>
+<p>Modules for constructing graphs.</p>
+<p>´GraphDefinition´ defines the nodes and their features,  and contains general
+graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.</p>
 <p><h2> Subpackages </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
@@ -492,8 +496,14 @@ <h1 id="api-graphnet-models-graphs--page-root">graphs<a class="headerlink" href=
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.graph_definition.html">graph_definition</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.graphs.html">graphs</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.graph_definition.html">graph_definition</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition"><code class="docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.graphs.html">graphs</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.graphs.html#graphnet.models.graphs.graphs.KNNGraph"><code class="docutils literal notranslate"><span class="pre">KNNGraph</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -545,7 +555,7 @@ <h1 id="api-graphnet-models-graphs--page-root">graphs<a class="headerlink" href=
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.graphs.nodes.html b/api/graphnet.models.graphs.nodes.html
index 1609b8099..2e99388bb 100644
--- a/api/graphnet.models.graphs.nodes.html
+++ b/api/graphnet.models.graphs.nodes.html
@@ -482,12 +482,20 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="nodes">
-<h1 id="api-graphnet-models-graphs-nodes--page-root">nodes<a class="headerlink" href="#api-graphnet-models-graphs-nodes--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.graphs.nodes">
+<span id="nodes"></span><h1 id="api-graphnet-models-graphs-nodes--page-root">nodes<a class="headerlink" href="#api-graphnet-models-graphs-nodes--page-root" title="Link to this heading">¶</a></h1>
+<p>Modules for constructing graphs.</p>
+<p>´GraphDefinition´ defines the nodes and their features,  and contains general
+graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes
+and their features.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html">nodes</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html">nodes</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition"><code class="docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodesAsPulses"><code class="docutils literal notranslate"><span class="pre">NodesAsPulses</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -539,7 +547,7 @@ <h1 id="api-graphnet-models-graphs-nodes--page-root">nodes<a class="headerlink"
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.graphs.nodes.nodes.html b/api/graphnet.models.graphs.nodes.nodes.html
index 23b5575bd..a5ac7ba2d 100644
--- a/api/graphnet.models.graphs.nodes.nodes.html
+++ b/api/graphnet.models.graphs.nodes.nodes.html
@@ -370,11 +370,63 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-graphs-nodes-nodes--page-root" class="md-nav__link">nodes</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">set_number_of_inputs()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodesAsPulses" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodesAsPulses</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">set_number_of_inputs()</span></code></a>
       
     
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.graphs.nodes.nodes.NodesAsPulses" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodesAsPulses</span></code></a>
+      
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -473,7 +525,24 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-graphs-nodes-nodes--page-root" class="md-nav__link">nodes</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">set_number_of_inputs()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodesAsPulses" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodesAsPulses</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -483,8 +552,65 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="nodes">
-<h1 id="api-graphnet-models-graphs-nodes-nodes--page-root">nodes<a class="headerlink" href="#api-graphnet-models-graphs-nodes-nodes--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.graphs.nodes.nodes">
+<span id="nodes"></span><h1 id="api-graphnet-models-graphs-nodes-nodes--page-root">nodes<a class="headerlink" href="#api-graphnet-models-graphs-nodes-nodes--page-root" title="Link to this heading">¶</a></h1>
+<p>Class(es) for building/connecting graphs.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodeDefinition">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.nodes.nodes.</span></span><span class="sig-name descname"><span class="pre">NodeDefinition</span></span><a class="reference internal" href="../_modules/graphnet/models/graphs/nodes/nodes.html#NodeDefinition"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>Base class for graph building.</p>
+<p>Construct <cite>Detector</cite>.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodeDefinition.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/nodes/nodes.html#NodeDefinition.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Construct nodes from raw node features.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>x</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">tensor</span></code>) – standardized node features with shape ´[num_pulses, d]´,</p></li>
+<li><p><strong>features.</strong> (<em>where ´d´ is the number</em><em> of </em><em>node</em>) – </p></li>
+</ul>
+</dd>
+<dt class="field-even">Returns<span class="colon">:</span></dt>
+<dd class="field-even"><p>a graph without edges</p>
+</dd>
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p>graph</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_outputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs" title="Link to this definition">¶</a></dt>
+<dd><p>Return number of output features.</p>
+<p>This the default, but may be overridden by specific inheriting classes.</p>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs">
+<span class="sig-name descname"><span class="pre">set_number_of_inputs</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/nodes/nodes.html#NodeDefinition.set_number_of_inputs"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs" title="Link to this definition">¶</a></dt>
+<dd><p>Return number of inputs expected by node definition.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>node_feature_names</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – name of each node feature column.</p>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodesAsPulses">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.nodes.nodes.</span></span><span class="sig-name descname"><span class="pre">NodesAsPulses</span></span><a class="reference internal" href="../_modules/graphnet/models/graphs/nodes/nodes.html#NodesAsPulses"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodesAsPulses" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" title="graphnet.models.graphs.nodes.nodes.NodeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a></p>
+<p>Represent each measured pulse of Cherenkov Radiation as a node.</p>
+<p>Construct <cite>Detector</cite>.</p>
+<dl class="field-list simple">
+</dl>
+</dd></dl>
 </section>
 
 
@@ -534,7 +660,7 @@ <h1 id="api-graphnet-models-graphs-nodes-nodes--page-root">nodes<a class="header
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.html b/api/graphnet.models.html
index fa38a95a6..d514e1468 100644
--- a/api/graphnet.models.html
+++ b/api/graphnet.models.html
@@ -445,8 +445,14 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="models">
-<h1 id="api-graphnet-models--page-root">models<a class="headerlink" href="#api-graphnet-models--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models">
+<span id="models"></span><h1 id="api-graphnet-models--page-root">models<a class="headerlink" href="#api-graphnet-models--page-root" title="Link to this heading">¶</a></h1>
+<p>Modules for configuring and building models.</p>
+<p><cite>graphnet.models</cite> allows for configuring and building complex GNN models using
+simple, physics-oriented components. This module provides modular components
+subclassing <cite>torch.nn.Module</cite>, meaning that users only need to import a few,
+existing, purpose-built components and chain them together to form a complete
+GNN</p>
 <p><h2> Subpackages </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
@@ -487,10 +493,29 @@ <h1 id="api-graphnet-models--page-root">models<a class="headerlink" href="#api-g
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.coarsening.html">coarsening</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.model.html">model</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.standard_model.html">standard_model</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.utils.html">utils</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.coarsening.html">coarsening</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.unbatch_edge_index"><code class="docutils literal notranslate"><span class="pre">unbatch_edge_index()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening"><code class="docutils literal notranslate"><span class="pre">Coarsening</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.AttributeCoarsening"><code class="docutils literal notranslate"><span class="pre">AttributeCoarsening</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.DOMCoarsening"><code class="docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.CustomDOMCoarsening"><code class="docutils literal notranslate"><span class="pre">CustomDOMCoarsening</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.DOMAndTimeWindowCoarsening"><code class="docutils literal notranslate"><span class="pre">DOMAndTimeWindowCoarsening</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.model.html">model</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model"><code class="docutils literal notranslate"><span class="pre">Model</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.standard_model.html">standard_model</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel"><code class="docutils literal notranslate"><span class="pre">StandardModel</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.utils.html">utils</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.utils.html#graphnet.models.utils.calculate_xyzt_homophily"><code class="docutils literal notranslate"><span class="pre">calculate_xyzt_homophily()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.utils.html#graphnet.models.utils.calculate_distance_matrix"><code class="docutils literal notranslate"><span class="pre">calculate_distance_matrix()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.utils.html#graphnet.models.utils.knn_graph_batch"><code class="docutils literal notranslate"><span class="pre">knn_graph_batch()</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -542,7 +567,7 @@ <h1 id="api-graphnet-models--page-root">models<a class="headerlink" href="#api-g
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.model.html b/api/graphnet.models.model.html
index f1bc235f8..286f23f20 100644
--- a/api/graphnet.models.model.html
+++ b/api/graphnet.models.model.html
@@ -372,11 +372,108 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-model--page-root" class="md-nav__link">model</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Model</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.save" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.save_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_state_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.load_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_state_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Model</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.save" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.save_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_state_dict()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.load_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_state_dict()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.model.Model.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -436,7 +533,34 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-model--page-root" class="md-nav__link">model</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Model</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.save" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.save_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_state_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.load_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_state_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.model.Model.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -446,8 +570,183 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="model">
-<h1 id="api-graphnet-models-model--page-root">model<a class="headerlink" href="#api-graphnet-models-model--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.model">
+<span id="model"></span><h1 id="api-graphnet-models-model--page-root">model<a class="headerlink" href="#api-graphnet-models-model--page-root" title="Link to this heading">¶</a></h1>
+<p>Base class(es) for building models.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.model.Model">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.model.</span></span><span class="sig-name descname"><span class="pre">Model</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">class_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">level</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">log_folder</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger" title="graphnet.utilities.logging.Logger"><code class="xref py py-class docutils literal notranslate"><span class="pre">Logger</span></code></a>, <a class="reference internal" href="graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable" title="graphnet.utilities.config.configurable.Configurable"><code class="xref py py-class docutils literal notranslate"><span class="pre">Configurable</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code></p>
+<p>Base class for all models in graphnet.</p>
+<p>Construct <cite>Logger</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>name</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>class_name</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>level</strong> (<em>int</em>) – </p></li>
+<li><p><strong>log_folder</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.forward">
+<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>x</strong> (<em>Tensor</em><em> | </em><em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.fit">
+<span class="sig-name descname"><span class="pre">fit</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">train_dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">val_dataloader</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_epochs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">callbacks</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ckpt_path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">logger</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">log_every_n_steps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gradient_clip_val</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">trainer_kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.fit"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.fit" title="Link to this definition">¶</a></dt>
+<dd><p>Fit <cite>Model</cite> using <cite>pytorch_lightning.Trainer</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>train_dataloader</strong> (<em>DataLoader</em>) – </p></li>
+<li><p><strong>val_dataloader</strong> (<em>DataLoader</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>max_epochs</strong> (<em>int</em>) – </p></li>
+<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>callbacks</strong> (<em>List</em><em>[</em><em>Callback</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>ckpt_path</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>logger</strong> (<em>Logger</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>log_every_n_steps</strong> (<em>int</em>) – </p></li>
+<li><p><strong>gradient_clip_val</strong> (<em>float</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>trainer_kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.predict">
+<span class="sig-name descname"><span class="pre">predict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.predict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.predict" title="Link to this definition">¶</a></dt>
+<dd><p>Return predictions for <cite>dataloader</cite>.</p>
+<p>Returns a list of Tensors, one for each model output.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li>
+<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.predict_as_dataframe">
+<span class="sig-name descname"><span class="pre">predict_as_dataframe</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">additional_attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.predict_as_dataframe"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.predict_as_dataframe" title="Link to this definition">¶</a></dt>
+<dd><p>Return predictions for <cite>dataloader</cite> as a DataFrame.</p>
+<p>Include <cite>additional_attributes</cite> as additional columns in the output
+DataFrame.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">DataFrame</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li>
+<li><p><strong>prediction_columns</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>additional_attributes</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.save">
+<span class="sig-name descname"><span class="pre">save</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.save"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.save" title="Link to this definition">¶</a></dt>
+<dd><p>Save entire model to <cite>path</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.load">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">load</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.load"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.load" title="Link to this definition">¶</a></dt>
+<dd><p>Load entire model from <cite>path</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><a class="reference internal" href="#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.save_state_dict">
+<span class="sig-name descname"><span class="pre">save_state_dict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.save_state_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.save_state_dict" title="Link to this definition">¶</a></dt>
+<dd><p>Save model <cite>state_dict</cite> to <cite>path</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.load_state_dict">
+<span class="sig-name descname"><span class="pre">load_state_dict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.load_state_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.load_state_dict" title="Link to this definition">¶</a></dt>
+<dd><p>Load model <cite>state_dict</cite> from <cite>path</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><a class="reference internal" href="#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>path</strong> (<em>str</em><em> | </em><em>Dict</em>) – </p></li>
+<li><p><strong>kargs</strong> (<em>Any</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.model.Model.from_config">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">from_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">source</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">trust</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">load_modules</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.from_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.from_config" title="Link to this definition">¶</a></dt>
+<dd><p>Construct <cite>Model</cite> instance from <cite>source</cite> configuration.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>trust</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – Whether to trust the ModelConfig file enough to <cite>eval(…)</cite>
+any lambda function expressions contained.</p></li>
+<li><p><strong>load_modules</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of modules used in the definition of the model
+which, as a consequence, need to be loaded into the global
+namespace. Defaults to loading <cite>torch</cite>.</p></li>
+<li><p><strong>source</strong> (<a class="reference internal" href="graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig" title="graphnet.utilities.config.model_config.ModelConfig"><em>ModelConfig</em></a><em> | </em><em>str</em>) – </p></li>
+</ul>
+</dd>
+<dt class="field-even">Raises<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>ValueError</strong> – If the ModelConfig contains lambda functions but
+    <cite>trust = False</cite>.</p>
+</dd>
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><a class="reference internal" href="#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -497,7 +796,7 @@ <h1 id="api-graphnet-models-model--page-root">model<a class="headerlink" href="#
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.standard_model.html b/api/graphnet.models.standard_model.html
index da580b529..b886a0725 100644
--- a/api/graphnet.models.standard_model.html
+++ b/api/graphnet.models.standard_model.html
@@ -379,11 +379,135 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-standard-model--page-root" class="md-nav__link">standard_model</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">StandardModel</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.configure_optimizers" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">configure_optimizers()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.shared_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">shared_step()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.training_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">training_step()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.validation_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">validation_step()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.train" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">StandardModel</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.configure_optimizers" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">configure_optimizers()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.shared_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">shared_step()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.training_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">training_step()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.validation_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">validation_step()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.train" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.standard_model.StandardModel.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -436,7 +560,40 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-standard-model--page-root" class="md-nav__link">standard_model</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">StandardModel</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.configure_optimizers" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">configure_optimizers()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.shared_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">shared_step()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.training_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">training_step()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.validation_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">validation_step()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.train" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -446,8 +603,193 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="standard-model">
-<h1 id="api-graphnet-models-standard-model--page-root">standard_model<a class="headerlink" href="#api-graphnet-models-standard-model--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.standard_model">
+<span id="standard-model"></span><h1 id="api-graphnet-models-standard-model--page-root">standard_model<a class="headerlink" href="#api-graphnet-models-standard-model--page-root" title="Link to this heading">¶</a></h1>
+<p>Standard model class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.standard_model.</span></span><span class="sig-name descname"><span class="pre">StandardModel</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gnn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tasks</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optimizer_class=&lt;class</span> <span class="pre">'torch.optim.adam.Adam'&gt;</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optimizer_kwargs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scheduler_class</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scheduler_kwargs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scheduler_config</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>Main class for standard models in graphnet.</p>
+<p>This class chains together the different elements of a complete GNN-based
+model (detector read-in, GNN architecture, and task-specific read-outs).</p>
+<p>Construct <cite>StandardModel</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><em>GraphDefinition</em></a>) – </p></li>
+<li><p><strong>gnn</strong> (<a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><em>GNN</em></a>) – </p></li>
+<li><p><strong>tasks</strong> (<a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><em>Task</em></a><em> | </em><em>List</em><em>[</em><a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><em>Task</em></a><em>]</em>) – </p></li>
+<li><p><strong>optimizer_class</strong> (<em>type</em>) – </p></li>
+<li><p><strong>optimizer_kwargs</strong> (<em>Dict</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>scheduler_class</strong> (<em>type</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>scheduler_kwargs</strong> (<em>Dict</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>scheduler_config</strong> (<em>Dict</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.target_labels">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">target_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.target_labels" title="Link to this definition">¶</a></dt>
+<dd><p>Return target label.</p>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.prediction_labels">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">prediction_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.prediction_labels" title="Link to this definition">¶</a></dt>
+<dd><p>Return prediction labels.</p>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.configure_optimizers">
+<span class="sig-name descname"><span class="pre">configure_optimizers</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.configure_optimizers"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.configure_optimizers" title="Link to this definition">¶</a></dt>
+<dd><p>Configure the model’s optimizer(s).</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Forward pass, chaining model components.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>]]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.shared_step">
+<span class="sig-name descname"><span class="pre">shared_step</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_idx</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.shared_step"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.shared_step" title="Link to this definition">¶</a></dt>
+<dd><p>Perform shared step.</p>
+<p>Applies the forward pass and the following loss calculation, shared
+between the training and validation step.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>batch</strong> (<em>Data</em>) – </p></li>
+<li><p><strong>batch_idx</strong> (<em>int</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.training_step">
+<span class="sig-name descname"><span class="pre">training_step</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">train_batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_idx</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.training_step"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.training_step" title="Link to this definition">¶</a></dt>
+<dd><p>Perform training step.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>train_batch</strong> (<em>Data</em>) – </p></li>
+<li><p><strong>batch_idx</strong> (<em>int</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.validation_step">
+<span class="sig-name descname"><span class="pre">validation_step</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">val_batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_idx</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.validation_step"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.validation_step" title="Link to this definition">¶</a></dt>
+<dd><p>Perform validation step.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>val_batch</strong> (<em>Data</em>) – </p></li>
+<li><p><strong>batch_idx</strong> (<em>int</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.compute_loss">
+<span class="sig-name descname"><span class="pre">compute_loss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">verbose</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.compute_loss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.compute_loss" title="Link to this definition">¶</a></dt>
+<dd><p>Compute and sum losses across tasks.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>preds</strong> (<em>Tensor</em>) – </p></li>
+<li><p><strong>data</strong> (<em>Data</em>) – </p></li>
+<li><p><strong>verbose</strong> (<em>bool</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.inference">
+<span class="sig-name descname"><span class="pre">inference</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.inference"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.inference" title="Link to this definition">¶</a></dt>
+<dd><p>Activate inference mode.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.train">
+<span class="sig-name descname"><span class="pre">train</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">mode</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.train"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.train" title="Link to this definition">¶</a></dt>
+<dd><p>Deactivate inference mode.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>mode</strong> (<em>bool</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.predict">
+<span class="sig-name descname"><span class="pre">predict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.predict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.predict" title="Link to this definition">¶</a></dt>
+<dd><p>Return predictions for <cite>dataloader</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li>
+<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.predict_as_dataframe">
+<span class="sig-name descname"><span class="pre">predict_as_dataframe</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">additional_attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.predict_as_dataframe"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.predict_as_dataframe" title="Link to this definition">¶</a></dt>
+<dd><p>Return predictions for <cite>dataloader</cite> as a DataFrame.</p>
+<p>Include <cite>additional_attributes</cite> as additional columns in the output
+DataFrame.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">DataFrame</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li>
+<li><p><strong>prediction_columns</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>additional_attributes</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -497,7 +839,7 @@ <h1 id="api-graphnet-models-standard-model--page-root">standard_model<a class="h
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.task.classification.html b/api/graphnet.models.task.classification.html
index 3fb005f08..50cf9ca15 100644
--- a/api/graphnet.models.task.classification.html
+++ b/api/graphnet.models.task.classification.html
@@ -364,11 +364,101 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-task-classification--page-root" class="md-nav__link">classification</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.MulticlassClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MulticlassClassificationTask</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTask</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTaskLogits</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.MulticlassClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MulticlassClassificationTask</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.BinaryClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTask</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTaskLogits</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -458,7 +548,34 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-task-classification--page-root" class="md-nav__link">classification</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.MulticlassClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MulticlassClassificationTask</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTask</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTaskLogits</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -468,8 +585,147 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="classification">
-<h1 id="api-graphnet-models-task-classification--page-root">classification<a class="headerlink" href="#api-graphnet-models-task-classification--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.task.classification">
+<span id="classification"></span><h1 id="api-graphnet-models-task-classification--page-root">classification<a class="headerlink" href="#api-graphnet-models-task-classification--page-root" title="Link to this heading">¶</a></h1>
+<p>Classification-specific <cite>Model</cite> class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.MulticlassClassificationTask">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.classification.</span></span><span class="sig-name descname"><span class="pre">MulticlassClassificationTask</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/classification.html#MulticlassClassificationTask"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.classification.MulticlassClassificationTask" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask" title="graphnet.models.task.task.IdentityTask"><code class="xref py py-class docutils literal notranslate"><span class="pre">IdentityTask</span></code></a></p>
+<p>General task for classifying any number of classes.</p>
+<p>Requires the same number of input features as the number of classes being
+predicted. Returns the untransformed latent features, which are interpreted
+as the logits for each class being classified.</p>
+<p>Construct IdentityTask.</p>
+<p>Return the <cite>nb_outputs</cite> as a direct, affine transformation of the last
+hidden layer.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nb_outputs</strong> (<em>int</em>) – </p></li>
+<li><p><strong>target_labels</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>Any</em>) – </p></li>
+<li><p><strong>args</strong> (<em>Any</em>) – </p></li>
+<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTask">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.classification.</span></span><span class="sig-name descname"><span class="pre">BinaryClassificationTask</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/classification.html#BinaryClassificationTask"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTask" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Performs binary classification.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTask.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTask.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['target']</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['target_pred']</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTaskLogits">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.classification.</span></span><span class="sig-name descname"><span class="pre">BinaryClassificationTaskLogits</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/classification.html#BinaryClassificationTaskLogits"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTaskLogits" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Performs binary classification form logits.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['target']</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['target_pred']</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
 </section>
 
 
@@ -519,7 +775,7 @@ <h1 id="api-graphnet-models-task-classification--page-root">classification<a cla
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.task.html b/api/graphnet.models.task.html
index 1df19c8e8..0ed5f54b5 100644
--- a/api/graphnet.models.task.html
+++ b/api/graphnet.models.task.html
@@ -467,14 +467,38 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="task">
-<h1 id="api-graphnet-models-task--page-root">task<a class="headerlink" href="#api-graphnet-models-task--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.task">
+<span id="task"></span><h1 id="api-graphnet-models-task--page-root">task<a class="headerlink" href="#api-graphnet-models-task--page-root" title="Link to this heading">¶</a></h1>
+<p>Physics task-specific modules to be used as model “read-outs”.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.classification.html">classification</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.reconstruction.html">reconstruction</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.task.html">task</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.classification.html">classification</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.classification.html#graphnet.models.task.classification.MulticlassClassificationTask"><code class="docutils literal notranslate"><span class="pre">MulticlassClassificationTask</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTask</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTaskLogits</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.reconstruction.html">reconstruction</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstruction</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa"><code class="docutils literal notranslate"><span class="pre">DirectionReconstructionWithKappa</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction"><code class="docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa"><code class="docutils literal notranslate"><span class="pre">ZenithReconstructionWithKappa</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction"><code class="docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithPower</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithUncertainty</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction"><code class="docutils literal notranslate"><span class="pre">VertexReconstruction</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction"><code class="docutils literal notranslate"><span class="pre">PositionReconstruction</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction"><code class="docutils literal notranslate"><span class="pre">TimeReconstruction</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction"><code class="docutils literal notranslate"><span class="pre">InelasticityReconstruction</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.task.html">task</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task"><code class="docutils literal notranslate"><span class="pre">Task</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask"><code class="docutils literal notranslate"><span class="pre">IdentityTask</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -526,7 +550,7 @@ <h1 id="api-graphnet-models-task--page-root">task<a class="headerlink" href="#ap
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.task.reconstruction.html b/api/graphnet.models.task.reconstruction.html
index 1d48b9eb0..06b772e83 100644
--- a/api/graphnet.models.task.reconstruction.html
+++ b/api/graphnet.models.task.reconstruction.html
@@ -371,11 +371,472 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-task-reconstruction--page-root" class="md-nav__link">reconstruction</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DirectionReconstructionWithKappa</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstructionWithKappa</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithPower</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithUncertainty</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VertexReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PositionReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TimeReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InelasticityReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.AzimuthReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstruction</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DirectionReconstructionWithKappa</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.ZenithReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstructionWithKappa</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithPower</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithUncertainty</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.VertexReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VertexReconstruction</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.PositionReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PositionReconstruction</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.TimeReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TimeReconstruction</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.InelasticityReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InelasticityReconstruction</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -458,7 +919,132 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-task-reconstruction--page-root" class="md-nav__link">reconstruction</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DirectionReconstructionWithKappa</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstructionWithKappa</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithPower</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithUncertainty</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VertexReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PositionReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TimeReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InelasticityReconstruction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -468,8 +1054,706 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="reconstruction">
-<h1 id="api-graphnet-models-task-reconstruction--page-root">reconstruction<a class="headerlink" href="#api-graphnet-models-task-reconstruction--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.task.reconstruction">
+<span id="reconstruction"></span><h1 id="api-graphnet-models-task-reconstruction--page-root">reconstruction<a class="headerlink" href="#api-graphnet-models-task-reconstruction--page-root" title="Link to this heading">¶</a></h1>
+<p>Reconstruction-specific <cite>Model</cite> class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">AzimuthReconstructionWithKappa</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#AzimuthReconstructionWithKappa"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs azimuthal angle and associated kappa (1/var).</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['azimuth']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['azimuth_pred',</span> <span class="pre">'azimuth_kappa']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">2</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstruction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">AzimuthReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#AzimuthReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstruction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" title="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa"><code class="xref py py-class docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a></p>
+<p>Reconstructs azimuthal angle.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['azimuth']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['azimuth_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">2</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.DirectionReconstructionWithKappa">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">DirectionReconstructionWithKappa</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#DirectionReconstructionWithKappa"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs direction with kappa from the 3D-vMF distribution.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['direction']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['dir_x_pred',</span> <span class="pre">'dir_y_pred',</span> <span class="pre">'dir_z_pred',</span> <span class="pre">'direction_kappa']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">3</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstruction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">ZenithReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#ZenithReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstruction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs zenith angle.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['zenith']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['zenith_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstructionWithKappa">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">ZenithReconstructionWithKappa</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#ZenithReconstructionWithKappa"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.task.reconstruction.ZenithReconstruction" title="graphnet.models.task.reconstruction.ZenithReconstruction"><code class="xref py py-class docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a></p>
+<p>Reconstructs zenith angle and associated kappa (1/var).</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['zenith']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['zenith_pred',</span> <span class="pre">'zenith_kappa']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">2</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstruction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">EnergyReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#EnergyReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstruction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs energy using stable method.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithPower">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">EnergyReconstructionWithPower</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#EnergyReconstructionWithPower"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs energy.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">EnergyReconstructionWithUncertainty</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#EnergyReconstructionWithUncertainty"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.task.reconstruction.EnergyReconstruction" title="graphnet.models.task.reconstruction.EnergyReconstruction"><code class="xref py py-class docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a></p>
+<p>Reconstructs energy and associated uncertainty (log(var)).</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy_pred',</span> <span class="pre">'energy_sigma']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">2</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.VertexReconstruction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">VertexReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#VertexReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.VertexReconstruction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs vertex position and time.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['vertex']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['position_x_pred',</span> <span class="pre">'position_y_pred',</span> <span class="pre">'position_z_pred',</span> <span class="pre">'interaction_time_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">4</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.PositionReconstruction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">PositionReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#PositionReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.PositionReconstruction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs vertex position.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['position']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['position_x_pred',</span> <span class="pre">'position_y_pred',</span> <span class="pre">'position_z_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">3</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.TimeReconstruction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">TimeReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#TimeReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.TimeReconstruction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs time.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['interaction_time']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['interaction_time_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.InelasticityReconstruction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">InelasticityReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#InelasticityReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.InelasticityReconstruction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Reconstructs interaction inelasticity.</p>
+<p>That is, 1-(track energy / hadronic energy).</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels">
+<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['inelasticity']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels">
+<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['inelasticity_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs">
+<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+</dd></dl>
 </section>
 
 
@@ -519,7 +1803,7 @@ <h1 id="api-graphnet-models-task-reconstruction--page-root">reconstruction<a cla
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.task.task.html b/api/graphnet.models.task.task.html
index 1d38d3bae..35d5c451e 100644
--- a/api/graphnet.models.task.task.html
+++ b/api/graphnet.models.task.task.html
@@ -378,11 +378,128 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-task-task--page-root" class="md-nav__link">task</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Task</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.train_eval" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train_eval()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IdentityTask</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.Task" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Task</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.Task.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.Task.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.Task.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.Task.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.Task.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.Task.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.Task.train_eval" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train_eval()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.IdentityTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IdentityTask</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.IdentityTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.IdentityTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.task.task.IdentityTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -458,7 +575,40 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-task-task--page-root" class="md-nav__link">task</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Task</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.train_eval" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train_eval()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IdentityTask</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -468,8 +618,154 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="task">
-<h1 id="api-graphnet-models-task-task--page-root">task<a class="headerlink" href="#api-graphnet-models-task-task--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.task.task">
+<span id="task"></span><h1 id="api-graphnet-models-task-task--page-root">task<a class="headerlink" href="#api-graphnet-models-task-task--page-root" title="Link to this heading">¶</a></h1>
+<p>Base physics task-specific <cite>Model</cite> class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.task.Task">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.task.</span></span><span class="sig-name descname"><span class="pre">Task</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>Base class for all reconstruction and classification tasks.</p>
+<p>Construct <cite>Task</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this
+tasks, used to construct the affine transformation to the
+predicted quantity.</p></li>
+<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li>
+<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used
+to extract the  target tensor(s) from the <cite>Data</cite> object in
+<cite>.compute_loss(…)</cite>.</p></li>
+<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by
+the model during inference. If not given, the name will auto
+matically be set to <cite>target_label + _pred</cite>.</p></li>
+<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform
+both the predicted and target tensor before passing them to the
+loss function. Useful e.g. for having the model predict
+quantities on a physical scale, but transforming this scale to
+O(1) for a numerically stable loss computation.</p></li>
+<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target
+tensor before passing it, and the predicted tensor, to the loss
+function. Useful e.g. for having the model predict a
+transformed version of the target quantity, e.g. the log10-
+scaled energy, rather than the physical quantity itself. Used
+in conjunction with <cite>transform_inference</cite> to perform the
+inverse transform on the predicted quantity to recover the
+physical scale.</p></li>
+<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the
+model prediction to recover a physical scale. Used in
+conjunction with <cite>transform_target</cite>.</p></li>
+<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum
+of the range of validity for the inverse transforms
+<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is
+restricted. By default the invertibility of <cite>transform_target</cite>
+is tested on the range [-1e6, 1e6].</p></li>
+<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event
+loss weights.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.task.task.Task.nb_inputs">
+<em class="property"><span class="pre">abstract</span><span class="w"> </span><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.task.task.Task.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd><p>Return number of inputs assumed by task.</p>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.task.task.Task.default_target_labels">
+<em class="property"><span class="pre">abstract</span><span class="w"> </span><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.task.task.Task.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd><p>Return default target labels.</p>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.task.task.Task.default_prediction_labels">
+<em class="property"><span class="pre">abstract</span><span class="w"> </span><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.task.task.Task.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd><p>Return default prediction labels.</p>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.task.task.Task.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>x</strong> (<em>Tensor</em><em> | </em><em>Data</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.task.task.Task.compute_loss">
+<span class="sig-name descname"><span class="pre">compute_loss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pred</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task.compute_loss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task.compute_loss" title="Link to this definition">¶</a></dt>
+<dd><p>Compute loss of <cite>pred</cite> wrt.</p>
+<p>target labels in <cite>data</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>pred</strong> (<em>Tensor</em><em> | </em><em>Data</em>) – </p></li>
+<li><p><strong>data</strong> (<em>Data</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.task.task.Task.inference">
+<span class="sig-name descname"><span class="pre">inference</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task.inference"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task.inference" title="Link to this definition">¶</a></dt>
+<dd><p>Activate inference mode.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.models.task.task.Task.train_eval">
+<span class="sig-name descname"><span class="pre">train_eval</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task.train_eval"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task.train_eval" title="Link to this definition">¶</a></dt>
+<dd><p>Deactivate inference mode.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.models.task.task.IdentityTask">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.task.</span></span><span class="sig-name descname"><span class="pre">IdentityTask</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#IdentityTask"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.IdentityTask" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p>
+<p>Identity, or trivial, task.</p>
+<p>Construct IdentityTask.</p>
+<p>Return the <cite>nb_outputs</cite> as a direct, affine transformation of the last
+hidden layer.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>nb_outputs</strong> (<em>int</em>) – </p></li>
+<li><p><strong>target_labels</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>Any</em>) – </p></li>
+<li><p><strong>args</strong> (<em>Any</em>) – </p></li>
+<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.task.task.IdentityTask.default_target_labels">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.task.task.IdentityTask.default_target_labels" title="Link to this definition">¶</a></dt>
+<dd><p>Return default target labels.</p>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.task.task.IdentityTask.default_prediction_labels">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.task.task.IdentityTask.default_prediction_labels" title="Link to this definition">¶</a></dt>
+<dd><p>Return default prediction labels.</p>
+</dd></dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.models.task.task.IdentityTask.nb_inputs">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.task.task.IdentityTask.nb_inputs" title="Link to this definition">¶</a></dt>
+<dd><p>Return number of inputs assumed by task.</p>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -519,7 +815,7 @@ <h1 id="api-graphnet-models-task-task--page-root">task<a class="headerlink" href
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.models.utils.html b/api/graphnet.models.utils.html
index 6fa20aae4..475f25794 100644
--- a/api/graphnet.models.utils.html
+++ b/api/graphnet.models.utils.html
@@ -386,11 +386,43 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-utils--page-root" class="md-nav__link">utils</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.utils.calculate_xyzt_homophily" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_xyzt_homophily()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.utils.calculate_distance_matrix" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_distance_matrix()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.utils.knn_graph_batch" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">knn_graph_batch()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.utils.calculate_xyzt_homophily" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_xyzt_homophily()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.utils.calculate_distance_matrix" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_distance_matrix()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.models.utils.knn_graph_batch" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">knn_graph_batch()</span></code></a>
+      
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -436,7 +468,18 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-models-utils--page-root" class="md-nav__link">utils</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.models.utils.calculate_xyzt_homophily" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_xyzt_homophily()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.utils.calculate_distance_matrix" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_distance_matrix()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.models.utils.knn_graph_batch" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">knn_graph_batch()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -446,8 +489,69 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="utils">
-<h1 id="api-graphnet-models-utils--page-root">utils<a class="headerlink" href="#api-graphnet-models-utils--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.models.utils">
+<span id="utils"></span><h1 id="api-graphnet-models-utils--page-root">utils<a class="headerlink" href="#api-graphnet-models-utils--page-root" title="Link to this heading">¶</a></h1>
+<p>Utility functions for <cite>graphnet.models</cite>.</p>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.utils.calculate_xyzt_homophily">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.utils.</span></span><span class="sig-name descname"><span class="pre">calculate_xyzt_homophily</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/utils.html#calculate_xyzt_homophily"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.utils.calculate_xyzt_homophily" title="Link to this definition">¶</a></dt>
+<dd><p>Calculate xyzt-homophily from a batch of graphs.</p>
+<p>Homophily is a graph scalar quantity that measures the likeness of
+variables in nodes. Notice that this calculator assumes a special order of
+input features in x.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>]</p>
+</dd>
+<dt class="field-even">Returns<span class="colon">:</span></dt>
+<dd class="field-even"><p>Tuple, each element with shape [batch_size,1].</p>
+</dd>
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>x</strong> (<em>Tensor</em>) – </p></li>
+<li><p><strong>edge_index</strong> (<em>LongTensor</em>) – </p></li>
+<li><p><strong>batch</strong> (<em>Batch</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.utils.calculate_distance_matrix">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.utils.</span></span><span class="sig-name descname"><span class="pre">calculate_distance_matrix</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">xyz_coords</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/utils.html#calculate_distance_matrix"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.utils.calculate_distance_matrix" title="Link to this definition">¶</a></dt>
+<dd><p>Calculate the matrix of pairwise distances between pulses.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>xyz_coords</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – (x,y,z)-coordinates of pulses, of shape [nb_doms, 3].</p>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-odd">Returns<span class="colon">:</span></dt>
+<dd class="field-odd"><p>Matrix of pairwise distances, of shape [nb_doms, nb_doms]</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.models.utils.knn_graph_batch">
+<span class="sig-prename descclassname"><span class="pre">graphnet.models.utils.</span></span><span class="sig-name descname"><span class="pre">knn_graph_batch</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">k</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/utils.html#knn_graph_batch"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.utils.knn_graph_batch" title="Link to this definition">¶</a></dt>
+<dd><p>Calculate k-nearest-neighbours with individual k for each batch event.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>batch</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code>) – Batch of events.</p></li>
+<li><p><strong>k</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]) – A list of k’s.</p></li>
+<li><p><strong>columns</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]) – The columns of Data.x used for computing the distances. E.g.,
+Data.x[:,[0,1,2]]</p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code></p>
+</dd>
+<dt class="field-odd">Returns<span class="colon">:</span></dt>
+<dd class="field-odd"><p>Returns the same batch of events, but with updated edges.</p>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -497,7 +601,7 @@ <h1 id="api-graphnet-models-utils--page-root">utils<a class="headerlink" href="#
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.pisa.fitting.html b/api/graphnet.pisa.fitting.html
index 71412f39f..8bc401462 100644
--- a/api/graphnet.pisa.fitting.html
+++ b/api/graphnet.pisa.fitting.html
@@ -659,7 +659,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.pisa.html b/api/graphnet.pisa.html
index 610103161..77788eee8 100644
--- a/api/graphnet.pisa.html
+++ b/api/graphnet.pisa.html
@@ -465,7 +465,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.pisa.plotting.html b/api/graphnet.pisa.plotting.html
index 07ab22cb9..fd463920f 100644
--- a/api/graphnet.pisa.plotting.html
+++ b/api/graphnet.pisa.plotting.html
@@ -573,7 +573,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.training.callbacks.html b/api/graphnet.training.callbacks.html
index 62f336aab..2d3ffa1f3 100644
--- a/api/graphnet.training.callbacks.html
+++ b/api/graphnet.training.callbacks.html
@@ -344,11 +344,110 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-training-callbacks--page-root" class="md-nav__link">callbacks</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.PiecewiseLinearLR" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PiecewiseLinearLR</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.PiecewiseLinearLR.get_lr" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_lr()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ProgressBar</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_validation_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_validation_tqdm()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_predict_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_predict_tqdm()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_test_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_test_tqdm()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_train_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_train_tqdm()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.get_metrics" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_metrics()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_start" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_start()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_end" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_end()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.PiecewiseLinearLR" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PiecewiseLinearLR</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.PiecewiseLinearLR.get_lr" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_lr()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.ProgressBar" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ProgressBar</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.ProgressBar.init_validation_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_validation_tqdm()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.ProgressBar.init_predict_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_predict_tqdm()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.ProgressBar.init_test_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_test_tqdm()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.ProgressBar.init_train_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_train_tqdm()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.ProgressBar.get_metrics" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_metrics()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_start" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_start()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_end" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_end()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -408,7 +507,36 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-training-callbacks--page-root" class="md-nav__link">callbacks</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.PiecewiseLinearLR" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PiecewiseLinearLR</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.PiecewiseLinearLR.get_lr" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_lr()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ProgressBar</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_validation_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_validation_tqdm()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_predict_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_predict_tqdm()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_test_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_test_tqdm()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_train_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_train_tqdm()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.get_metrics" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_metrics()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_start" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_start()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_end" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_end()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -418,8 +546,151 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="callbacks">
-<h1 id="api-graphnet-training-callbacks--page-root">callbacks<a class="headerlink" href="#api-graphnet-training-callbacks--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.training.callbacks">
+<span id="callbacks"></span><h1 id="api-graphnet-training-callbacks--page-root">callbacks<a class="headerlink" href="#api-graphnet-training-callbacks--page-root" title="Link to this heading">¶</a></h1>
+<p>Callback class(es) for using during model training.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.PiecewiseLinearLR">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.callbacks.</span></span><span class="sig-name descname"><span class="pre">PiecewiseLinearLR</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">optimizer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">milestones</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">factors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">last_epoch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">verbose</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#PiecewiseLinearLR"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.PiecewiseLinearLR" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">_LRScheduler</span></code></p>
+<p>Interpolate learning rate linearly between milestones.</p>
+<p>Construct <cite>PiecewiseLinearLR</cite>.</p>
+<p>For each milestone, denoting a specified number of steps, a factor
+multiplying the base learning rate is specified. For steps between two
+milestones, the learning rate is interpolated linearly between the two
+closest milestones. For steps before the first milestone, the factor
+for the first milestone is used; vice versa for steps after the last
+milestone.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>optimizer</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Optimizer</span></code>) – Wrapped optimizer.</p></li>
+<li><p><strong>milestones</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]) – List of step indices. Must be increasing.</p></li>
+<li><p><strong>factors</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>]) – List of multiplicative factors. Must be same length as
+<cite>milestones</cite>.</p></li>
+<li><p><strong>last_epoch</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">-1</span></code>) – The index of the last epoch.</p></li>
+<li><p><strong>verbose</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, prints a message to stdout for each update.</p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.PiecewiseLinearLR.get_lr">
+<span class="sig-name descname"><span class="pre">get_lr</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#PiecewiseLinearLR.get_lr"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.PiecewiseLinearLR.get_lr" title="Link to this definition">¶</a></dt>
+<dd><p>Get effective learning rate(s) for each optimizer.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>]</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.callbacks.</span></span><span class="sig-name descname"><span class="pre">ProgressBar</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">refresh_rate</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">process_position</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">TQDMProgressBar</span></code></p>
+<p>Custom progress bar for graphnet.</p>
+<p>Customises the default progress in pytorch-lightning.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>refresh_rate</strong> (<em>int</em>) – </p></li>
+<li><p><strong>process_position</strong> (<em>int</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.init_validation_tqdm">
+<span class="sig-name descname"><span class="pre">init_validation_tqdm</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.init_validation_tqdm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.init_validation_tqdm" title="Link to this definition">¶</a></dt>
+<dd><p>Override for customisation.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Bar</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.init_predict_tqdm">
+<span class="sig-name descname"><span class="pre">init_predict_tqdm</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.init_predict_tqdm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.init_predict_tqdm" title="Link to this definition">¶</a></dt>
+<dd><p>Override for customisation.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Bar</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.init_test_tqdm">
+<span class="sig-name descname"><span class="pre">init_test_tqdm</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.init_test_tqdm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.init_test_tqdm" title="Link to this definition">¶</a></dt>
+<dd><p>Override for customisation.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Bar</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.init_train_tqdm">
+<span class="sig-name descname"><span class="pre">init_train_tqdm</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.init_train_tqdm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.init_train_tqdm" title="Link to this definition">¶</a></dt>
+<dd><p>Override for customisation.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Bar</span></code></p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.get_metrics">
+<span class="sig-name descname"><span class="pre">get_metrics</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">trainer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.get_metrics"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.get_metrics" title="Link to this definition">¶</a></dt>
+<dd><p>Override to not show the version number in the logging.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>trainer</strong> (<em>Trainer</em>) – </p></li>
+<li><p><strong>model</strong> (<em>LightningModule</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.on_train_epoch_start">
+<span class="sig-name descname"><span class="pre">on_train_epoch_start</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">trainer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.on_train_epoch_start"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_start" title="Link to this definition">¶</a></dt>
+<dd><p>Print the results of the previous epoch on a separate line.</p>
+<p>This allows the user to see the losses/metrics for previous epochs
+while the current is training. The default behaviour in pytorch-
+lightning is to overwrite the progress bar from previous epochs.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>trainer</strong> (<em>Trainer</em>) – </p></li>
+<li><p><strong>model</strong> (<em>LightningModule</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.on_train_epoch_end">
+<span class="sig-name descname"><span class="pre">on_train_epoch_end</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">trainer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.on_train_epoch_end"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_end" title="Link to this definition">¶</a></dt>
+<dd><p>Log the final progress bar for the epoch to file.</p>
+<p>Don’t duplciate to stdout.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>trainer</strong> (<em>Trainer</em>) – </p></li>
+<li><p><strong>model</strong> (<em>LightningModule</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -469,7 +740,7 @@ <h1 id="api-graphnet-training-callbacks--page-root">callbacks<a class="headerlin
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.training.html b/api/graphnet.training.html
index 3a7e0dc99..c9ac2b308 100644
--- a/api/graphnet.training.html
+++ b/api/graphnet.training.html
@@ -424,10 +424,38 @@
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.training.callbacks.html">callbacks</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.training.labels.html">labels</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.training.loss_functions.html">loss_functions</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.training.utils.html">utils</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.training.callbacks.html">callbacks</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR"><code class="docutils literal notranslate"><span class="pre">PiecewiseLinearLR</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar"><code class="docutils literal notranslate"><span class="pre">ProgressBar</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.training.labels.html">labels</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.labels.html#graphnet.training.labels.Label"><code class="docutils literal notranslate"><span class="pre">Label</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.labels.html#graphnet.training.labels.Direction"><code class="docutils literal notranslate"><span class="pre">Direction</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.training.loss_functions.html">loss_functions</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction"><code class="docutils literal notranslate"><span class="pre">LossFunction</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.MSELoss"><code class="docutils literal notranslate"><span class="pre">MSELoss</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.RMSELoss"><code class="docutils literal notranslate"><span class="pre">RMSELoss</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCoshLoss"><code class="docutils literal notranslate"><span class="pre">LogCoshLoss</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.CrossEntropyLoss"><code class="docutils literal notranslate"><span class="pre">CrossEntropyLoss</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.BinaryCrossEntropyLoss"><code class="docutils literal notranslate"><span class="pre">BinaryCrossEntropyLoss</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK"><code class="docutils literal notranslate"><span class="pre">LogCMK</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss"><code class="docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher2DLoss"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher2DLoss</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.EuclideanDistanceLoss"><code class="docutils literal notranslate"><span class="pre">EuclideanDistanceLoss</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher3DLoss"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher3DLoss</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.training.utils.html">utils</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.collate_fn"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.make_dataloader"><code class="docutils literal notranslate"><span class="pre">make_dataloader()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.make_train_validation_dataloader"><code class="docutils literal notranslate"><span class="pre">make_train_validation_dataloader()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.get_predictions"><code class="docutils literal notranslate"><span class="pre">get_predictions()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.save_results"><code class="docutils literal notranslate"><span class="pre">save_results()</span></code></a></li>
+</ul>
+</li>
 <li class="toctree-l1"><a class="reference internal" href="graphnet.training.weight_fitting.html">weight_fitting</a><ul>
 <li class="toctree-l2"><a class="reference internal" href="graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.WeightFitter"><code class="docutils literal notranslate"><span class="pre">WeightFitter</span></code></a></li>
 <li class="toctree-l2"><a class="reference internal" href="graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.Uniform"><code class="docutils literal notranslate"><span class="pre">Uniform</span></code></a></li>
@@ -485,7 +513,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.training.labels.html b/api/graphnet.training.labels.html
index cb464052f..6514d41a8 100644
--- a/api/graphnet.training.labels.html
+++ b/api/graphnet.training.labels.html
@@ -351,11 +351,45 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-training-labels--page-root" class="md-nav__link">labels</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.labels.Label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Label</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.labels.Label.key" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">key</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.labels.Direction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Direction</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.labels.Label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Label</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.labels.Label.key" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">key</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.labels.Direction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Direction</span></code></a>
       
     
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -408,7 +442,20 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-training-labels--page-root" class="md-nav__link">labels</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.labels.Label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Label</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.labels.Label.key" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">key</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.labels.Direction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Direction</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -418,8 +465,48 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="labels">
-<h1 id="api-graphnet-training-labels--page-root">labels<a class="headerlink" href="#api-graphnet-training-labels--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.training.labels">
+<span id="labels"></span><h1 id="api-graphnet-training-labels--page-root">labels<a class="headerlink" href="#api-graphnet-training-labels--page-root" title="Link to this heading">¶</a></h1>
+<p>Class(es) for constructing training labels at runtime.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.labels.Label">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.labels.</span></span><span class="sig-name descname"><span class="pre">Label</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">key</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/labels.html#Label"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.labels.Label" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code>, <a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger" title="graphnet.utilities.logging.Logger"><code class="xref py py-class docutils literal notranslate"><span class="pre">Logger</span></code></a></p>
+<p>Base <cite>Label</cite> class for producing labels from single <cite>Data</cite> instance.</p>
+<p>Construct <cite>Label</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>key</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – The name of the field in <cite>Data</cite> where the label will be
+stored. That is, <cite>graph[key] = label</cite>.</p>
+</dd>
+</dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.training.labels.Label.key">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">key</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">str</span></em><a class="headerlink" href="#graphnet.training.labels.Label.key" title="Link to this definition">¶</a></dt>
+<dd><p>Return value of <cite>key</cite>.</p>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.labels.Direction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.labels.</span></span><span class="sig-name descname"><span class="pre">Direction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">azimuth_key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">zenith_key</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/labels.html#Direction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.labels.Direction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.labels.Label" title="graphnet.training.labels.Label"><code class="xref py py-class docutils literal notranslate"><span class="pre">Label</span></code></a></p>
+<p>Class for producing particle direction/pointing label.</p>
+<p>Construct <cite>Direction</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>key</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'direction'</span></code>) – The name of the field in <cite>Data</cite> where the label will be
+stored. That is, <cite>graph[key] = label</cite>.</p></li>
+<li><p><strong>azimuth_key</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'azimuth'</span></code>) – The name of the pre-existing key in <cite>graph</cite> that will
+be used to access the azimiuth angle, used when calculating
+the direction.</p></li>
+<li><p><strong>zenith_key</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'zenith'</span></code>) – The name of the pre-existing key in <cite>graph</cite> that will
+be used to access the zenith angle, used when calculating the
+direction.</p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -469,7 +556,7 @@ <h1 id="api-graphnet-training-labels--page-root">labels<a class="headerlink" hre
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.training.loss_functions.html b/api/graphnet.training.loss_functions.html
index 61edd4160..feb15b1bf 100644
--- a/api/graphnet.training.loss_functions.html
+++ b/api/graphnet.training.loss_functions.html
@@ -358,11 +358,175 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-training-loss-functions--page-root" class="md-nav__link">loss_functions</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LossFunction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LossFunction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LossFunction.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.MSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MSELoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.RMSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RMSELoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCoshLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCoshLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.CrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CrossEntropyLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.BinaryCrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryCrossEntropyLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCMK</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK.backward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">backward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_exact()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_approx()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisher2DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher2DLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.EuclideanDistanceLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanDistanceLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisher3DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher3DLoss</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.LossFunction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LossFunction</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.LossFunction.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.MSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MSELoss</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.RMSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RMSELoss</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.LogCoshLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCoshLoss</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.CrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CrossEntropyLoss</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.BinaryCrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryCrossEntropyLoss</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.LogCMK" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCMK</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.LogCMK.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.LogCMK.backward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">backward()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.VonMisesFisherLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_exact()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_approx()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk()</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.VonMisesFisher2DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher2DLoss</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.EuclideanDistanceLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanDistanceLoss</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.loss_functions.VonMisesFisher3DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher3DLoss</span></code></a>
+      
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -408,7 +572,52 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-training-loss-functions--page-root" class="md-nav__link">loss_functions</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LossFunction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LossFunction</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LossFunction.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.MSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MSELoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.RMSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RMSELoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCoshLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCoshLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.CrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CrossEntropyLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.BinaryCrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryCrossEntropyLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCMK</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK.backward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">backward()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_exact()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_approx()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisher2DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher2DLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.EuclideanDistanceLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanDistanceLoss</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisher3DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher3DLoss</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -418,8 +627,281 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="loss-functions">
-<h1 id="api-graphnet-training-loss-functions--page-root">loss_functions<a class="headerlink" href="#api-graphnet-training-loss-functions--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.training.loss_functions">
+<span id="loss-functions"></span><h1 id="api-graphnet-training-loss-functions--page-root">loss_functions<a class="headerlink" href="#api-graphnet-training-loss-functions--page-root" title="Link to this heading">¶</a></h1>
+<p>Collection of loss functions.</p>
+<p>All loss functions inherit from <cite>LossFunction</cite> which ensures a common syntax,
+handles per-event weights, etc.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.LossFunction">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">LossFunction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LossFunction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LossFunction" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p>
+<p>Base class for loss functions in <cite>graphnet</cite>.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.LossFunction.forward">
+<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">prediction</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">weights</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">return_elements</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LossFunction.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LossFunction.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Forward pass for all loss functions.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>prediction</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – Tensor containing predictions. Shape [N,P]</p></li>
+<li><p><strong>target</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – Tensor containing targets. Shape [N,T]</p></li>
+<li><p><strong>return_elements</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – Whether elementwise loss terms should be returned.
+The alternative is to return the averaged loss across examples.</p></li>
+<li><p><strong>weights</strong> (<em>Tensor</em><em> | </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+<dt class="field-even">Return type<span class="colon">:</span></dt>
+<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-odd">Returns<span class="colon">:</span></dt>
+<dd class="field-odd"><p>Loss, either averaged to a scalar (if <cite>return_elements = False</cite>) or
+elementwise terms with shape [N,] (if <cite>return_elements = True</cite>).</p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.MSELoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">MSELoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#MSELoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.MSELoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p>
+<p>Mean squared error loss.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.RMSELoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">RMSELoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#RMSELoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.RMSELoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.MSELoss" title="graphnet.training.loss_functions.MSELoss"><code class="xref py py-class docutils literal notranslate"><span class="pre">MSELoss</span></code></a></p>
+<p>Root mean squared error loss.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.LogCoshLoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">LogCoshLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LogCoshLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LogCoshLoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p>
+<p>Log-cosh loss function.</p>
+<p>Acts like x^2 for small x; and like <a href="#id1"><span class="problematic" id="id2">|x|</span></a> for large x.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.CrossEntropyLoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">CrossEntropyLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">options</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#CrossEntropyLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.CrossEntropyLoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p>
+<p>Compute cross-entropy loss for classification tasks.</p>
+<p>Predictions are an [N, num_class]-matrix of logits (i.e., non-softmax’ed
+probabilities), and targets are an [N,1]-matrix with integer values in
+(0, num_classes - 1).</p>
+<p>Construct CrossEntropyLoss.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>options</strong> (<em>int</em><em> | </em><em>List</em><em>[</em><em>Any</em><em>] </em><em>| </em><em>Dict</em><em>[</em><em>Any</em><em>, </em><em>int</em><em>]</em>) – </p></li>
+<li><p><strong>args</strong> (<em>Any</em>) – </p></li>
+<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.BinaryCrossEntropyLoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">BinaryCrossEntropyLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#BinaryCrossEntropyLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.BinaryCrossEntropyLoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p>
+<p>Compute binary cross entropy loss.</p>
+<p>Predictions are vector probabilities (i.e., values between 0 and 1), and
+targets should be 0 and 1.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.LogCMK">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">LogCMK</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LogCMK"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LogCMK" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Function</span></code></p>
+<p>MIT License.</p>
+<p>Copyright (c) 2019 Max Ryabinin</p>
+<p>Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the “Software”), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:</p>
+<p>The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.</p>
+<p>THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+_____________________</p>
+<p>From [<a class="reference external" href="https://github.com/mryab/vmf_loss/blob/master/losses.py">https://github.com/mryab/vmf_loss/blob/master/losses.py</a>]
+Modified to use modified Bessel function instead of exponentially scaled ditto
+(i.e. <cite>.ive</cite> -&gt; <cite>.iv</cite>) as indiciated in [1812.04616] in spite of suggestion in
+Sec. 8.2 of this paper. The change has been validated through comparison with
+exact calculations for <cite>m=2</cite> and <cite>m=3</cite> and found to yield the correct results.</p>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.LogCMK.forward">
+<em class="property"><span class="pre">static</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">ctx</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LogCMK.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LogCMK.forward" title="Link to this definition">¶</a></dt>
+<dd><p>Forward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>ctx</strong> (<em>Any</em>) – </p></li>
+<li><p><strong>m</strong> (<em>int</em>) – </p></li>
+<li><p><strong>kappa</strong> (<em>Tensor</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.LogCMK.backward">
+<em class="property"><span class="pre">static</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">ctx</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">grad_output</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LogCMK.backward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LogCMK.backward" title="Link to this definition">¶</a></dt>
+<dd><p>Backward pass.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>ctx</strong> (<em>Any</em>) – </p></li>
+<li><p><strong>grad_output</strong> (<em>Tensor</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisherLoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">VonMisesFisherLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisherLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisherLoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p>
+<p>General class for calculating von Mises-Fisher loss.</p>
+<p>Requires implementation for specific dimension <cite>m</cite> in which the target and
+prediction vectors need to be prepared.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">log_cmk_exact</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisherLoss.log_cmk_exact"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact" title="Link to this definition">¶</a></dt>
+<dd><p>Calculate $log C_{m}(k)$ term in von Mises-Fisher loss exactly.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>m</strong> (<em>int</em>) – </p></li>
+<li><p><strong>kappa</strong> (<em>Tensor</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">log_cmk_approx</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisherLoss.log_cmk_approx"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx" title="Link to this definition">¶</a></dt>
+<dd><p>Calculate $log C_{m}(k)$ term in von Mises-Fisher loss approx.</p>
+<p>[<a class="reference external" href="https://arxiv.org/abs/1812.04616">https://arxiv.org/abs/1812.04616</a>] Sec. 8.2 with additional minus sign.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>m</strong> (<em>int</em>) – </p></li>
+<li><p><strong>kappa</strong> (<em>Tensor</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">log_cmk</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa_switch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisherLoss.log_cmk"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk" title="Link to this definition">¶</a></dt>
+<dd><p>Calculate $log C_{m}(k)$ term in von Mises-Fisher loss.</p>
+<p>Since <cite>log_cmk_exact</cite> is diverges for <cite>kappa</cite> &gt;~ 700 (using float64
+precision), and since <cite>log_cmk_approx</cite> is unaccurate for small <cite>kappa</cite>,
+this method automatically switches between the two at <cite>kappa_switch</cite>,
+ensuring continuity at this point.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>m</strong> (<em>int</em>) – </p></li>
+<li><p><strong>kappa</strong> (<em>Tensor</em>) – </p></li>
+<li><p><strong>kappa_switch</strong> (<em>float</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisher2DLoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">VonMisesFisher2DLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisher2DLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisher2DLoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.VonMisesFisherLoss" title="graphnet.training.loss_functions.VonMisesFisherLoss"><code class="xref py py-class docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a></p>
+<p>von Mises-Fisher loss function vectors in the 2D plane.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.EuclideanDistanceLoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">EuclideanDistanceLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#EuclideanDistanceLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.EuclideanDistanceLoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p>
+<p>Mean squared error in three dimensions.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisher3DLoss">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">VonMisesFisher3DLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisher3DLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisher3DLoss" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.VonMisesFisherLoss" title="graphnet.training.loss_functions.VonMisesFisherLoss"><code class="xref py py-class docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a></p>
+<p>von Mises-Fisher loss function vectors in the 3D plane.</p>
+<p>Construct <cite>LossFunction</cite>, saving model config.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -469,7 +951,7 @@ <h1 id="api-graphnet-training-loss-functions--page-root">loss_functions<a class=
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.training.utils.html b/api/graphnet.training.utils.html
index d29636c35..ac1d9919d 100644
--- a/api/graphnet.training.utils.html
+++ b/api/graphnet.training.utils.html
@@ -365,11 +365,61 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-training-utils--page-root" class="md-nav__link">utils</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.utils.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.utils.make_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_dataloader()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.utils.make_train_validation_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_train_validation_dataloader()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.utils.get_predictions" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_predictions()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.utils.save_results" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_results()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.utils.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.utils.make_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_dataloader()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.utils.make_train_validation_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_train_validation_dataloader()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.utils.get_predictions" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_predictions()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.training.utils.save_results" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_results()</span></code></a>
+      
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -408,7 +458,22 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-training-utils--page-root" class="md-nav__link">utils</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.training.utils.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.utils.make_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_dataloader()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.utils.make_train_validation_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_train_validation_dataloader()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.utils.get_predictions" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_predictions()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.training.utils.save_results" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_results()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -418,8 +483,128 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="utils">
-<h1 id="api-graphnet-training-utils--page-root">utils<a class="headerlink" href="#api-graphnet-training-utils--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.training.utils">
+<span id="utils"></span><h1 id="api-graphnet-training-utils--page-root">utils<a class="headerlink" href="#api-graphnet-training-utils--page-root" title="Link to this heading">¶</a></h1>
+<p>Utility functions for <cite>graphnet.training</cite>.</p>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.training.utils.collate_fn">
+<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">collate_fn</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">graphs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#collate_fn"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.collate_fn" title="Link to this definition">¶</a></dt>
+<dd><p>Remove graphs with less than two DOM hits.</p>
+<p>Should not occur in “production.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>graphs</strong> (<em>List</em><em>[</em><em>Data</em><em>]</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.training.utils.make_dataloader">
+<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">make_dataloader</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">db</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shuffle</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">persistent_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">labels</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#make_dataloader"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.make_dataloader" title="Link to this definition">¶</a></dt>
+<dd><p>Construct <cite>DataLoader</cite> instance.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>db</strong> (<em>str</em>) – </p></li>
+<li><p><strong>pulsemaps</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><em>GraphDefinition</em></a><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>features</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>truth</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>batch_size</strong> (<em>int</em>) – </p></li>
+<li><p><strong>shuffle</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>num_workers</strong> (<em>int</em>) – </p></li>
+<li><p><strong>persistent_workers</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>node_truth</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>truth_table</strong> (<em>str</em>) – </p></li>
+<li><p><strong>node_truth_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>string_selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>loss_weight_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>loss_weight_column</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>index_column</strong> (<em>str</em>) – </p></li>
+<li><p><strong>labels</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Callable</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.training.utils.make_train_validation_dataloader">
+<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">make_train_validation_dataloader</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">db</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">database_indices</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">test_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">persistent_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">labels</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#make_train_validation_dataloader"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.make_train_validation_dataloader" title="Link to this definition">¶</a></dt>
+<dd><p>Construct train and test <cite>DataLoader</cite> instances.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>db</strong> (<em>str</em>) – </p></li>
+<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><em>GraphDefinition</em></a><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>pulsemaps</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>features</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>truth</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>batch_size</strong> (<em>int</em>) – </p></li>
+<li><p><strong>database_indices</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>seed</strong> (<em>int</em>) – </p></li>
+<li><p><strong>test_size</strong> (<em>float</em>) – </p></li>
+<li><p><strong>num_workers</strong> (<em>int</em>) – </p></li>
+<li><p><strong>persistent_workers</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>node_truth</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>truth_table</strong> (<em>str</em>) – </p></li>
+<li><p><strong>node_truth_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>string_selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>loss_weight_column</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>loss_weight_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>index_column</strong> (<em>str</em>) – </p></li>
+<li><p><strong>labels</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Callable</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.training.utils.get_predictions">
+<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">get_predictions</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">trainer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_level</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">additional_attributes</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#get_predictions"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.get_predictions" title="Link to this definition">¶</a></dt>
+<dd><p>Get <cite>model</cite> predictions on <cite>dataloader</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">DataFrame</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>trainer</strong> (<em>Trainer</em>) – </p></li>
+<li><p><strong>model</strong> (<a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><em>Model</em></a>) – </p></li>
+<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li>
+<li><p><strong>prediction_columns</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>node_level</strong> (<em>bool</em>) – </p></li>
+<li><p><strong>additional_attributes</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.training.utils.save_results">
+<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">save_results</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">db</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tag</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">results</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">archive</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#save_results"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.save_results" title="Link to this definition">¶</a></dt>
+<dd><p>Save trained model and prediction <cite>results</cite> in <cite>db</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>db</strong> (<em>str</em>) – </p></li>
+<li><p><strong>tag</strong> (<em>str</em>) – </p></li>
+<li><p><strong>results</strong> (<em>DataFrame</em>) – </p></li>
+<li><p><strong>archive</strong> (<em>str</em>) – </p></li>
+<li><p><strong>model</strong> (<a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><em>Model</em></a>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -469,7 +654,7 @@ <h1 id="api-graphnet-training-utils--page-root">utils<a class="headerlink" href=
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.training.weight_fitting.html b/api/graphnet.training.weight_fitting.html
index aa7cedd58..78198cb33 100644
--- a/api/graphnet.training.weight_fitting.html
+++ b/api/graphnet.training.weight_fitting.html
@@ -611,7 +611,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.argparse.html b/api/graphnet.utilities.argparse.html
index 60f4ad179..814d2bdf2 100644
--- a/api/graphnet.utilities.argparse.html
+++ b/api/graphnet.utilities.argparse.html
@@ -639,7 +639,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.config.base_config.html b/api/graphnet.utilities.config.base_config.html
index dda866be3..48747c41d 100644
--- a/api/graphnet.utilities.config.base_config.html
+++ b/api/graphnet.utilities.config.base_config.html
@@ -357,11 +357,81 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-base-config--page-root" class="md-nav__link">base_config</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BaseConfig</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.dump" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dump()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.get_all_argument_values" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_argument_values()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.base_config.BaseConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BaseConfig</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.base_config.BaseConfig.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.base_config.BaseConfig.dump" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dump()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.base_config.BaseConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.base_config.BaseConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.base_config.BaseConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
       
     
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.base_config.get_all_argument_values" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_argument_values()</span></code></a>
+      
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -465,7 +535,28 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-base-config--page-root" class="md-nav__link">base_config</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BaseConfig</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.dump" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dump()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.get_all_argument_values" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_argument_values()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -475,8 +566,88 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="base-config">
-<h1 id="api-graphnet-utilities-config-base-config--page-root">base_config<a class="headerlink" href="#api-graphnet-utilities-config-base-config--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.utilities.config.base_config">
+<span id="base-config"></span><h1 id="api-graphnet-utilities-config-base-config--page-root">base_config<a class="headerlink" href="#api-graphnet-utilities-config-base-config--page-root" title="Link to this heading">¶</a></h1>
+<p>Base config class(es).</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.base_config.</span></span><span class="sig-name descname"><span class="pre">BaseConfig</span></span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#BaseConfig"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">BaseModel</span></code></p>
+<p>Base class for Configs.</p>
+<p>Create a new model by parsing and validating input data from keyword arguments.</p>
+<p>Raises [<cite>ValidationError</cite>][pydantic_core.ValidationError] if the input data cannot be
+validated to form a valid model.</p>
+<p><cite>__init__</cite> uses <cite>__pydantic_self__</cite> instead of the more common <cite>self</cite> for the first arg to
+allow <cite>self</cite> as a field name.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.load">
+<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">load</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#BaseConfig.load"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.load" title="Link to this definition">¶</a></dt>
+<dd><p>Load BaseConfig from <cite>path</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><a class="reference internal" href="#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.dump">
+<span class="sig-name descname"><span class="pre">dump</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#BaseConfig.dump"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.dump" title="Link to this definition">¶</a></dt>
+<dd><p>Save BaseConfig to <cite>path</cite> as YAML file, or return as string.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>path</strong> (<em>str</em><em> | </em><em>None</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.as_dict">
+<span class="sig-name descname"><span class="pre">as_dict</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#BaseConfig.as_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.as_dict" title="Link to this definition">¶</a></dt>
+<dd><p>Represent BaseConfig as a dict.</p>
+<p>This builds on <cite>BaseModel.dict()</cite> but can be overwritten.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.model_config">
+<span class="sig-name descname"><span class="pre">model_config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[ConfigDict]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.model_config" title="Link to this definition">¶</a></dt>
+<dd><p>Configuration for the model, should be a dictionary conforming to [<cite>ConfigDict</cite>][pydantic.config.ConfigDict].</p>
+</dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.model_fields">
+<span class="sig-name descname"><span class="pre">model_fields</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[dict[str,</span> <span class="pre">FieldInfo]]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.model_fields" title="Link to this definition">¶</a></dt>
+<dd><p>Metadata about the fields defined on the model,
+mapping of field names to [<cite>FieldInfo</cite>][pydantic.fields.FieldInfo].</p>
+<p>This replaces <cite>Model.__fields__</cite> from Pydantic V1.</p>
+</dd></dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.get_all_argument_values">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.base_config.</span></span><span class="sig-name descname"><span class="pre">get_all_argument_values</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#get_all_argument_values"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.get_all_argument_values" title="Link to this definition">¶</a></dt>
+<dd><p>Return dict of all argument values to <cite>fn</cite>, including defaults.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>fn</strong> (<em>Callable</em>) – </p></li>
+<li><p><strong>args</strong> (<em>Any</em>) – </p></li>
+<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -526,7 +697,7 @@ <h1 id="api-graphnet-utilities-config-base-config--page-root">base_config<a clas
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.config.configurable.html b/api/graphnet.utilities.config.configurable.html
index cc82e52e3..7654534eb 100644
--- a/api/graphnet.utilities.config.configurable.html
+++ b/api/graphnet.utilities.config.configurable.html
@@ -364,11 +364,54 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-configurable--page-root" class="md-nav__link">configurable</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Configurable</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.save_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_config()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.configurable.Configurable" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Configurable</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.configurable.Configurable.config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">config</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.configurable.Configurable.save_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_config()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.configurable.Configurable.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
       
     
+    </li></ul>
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -465,7 +508,22 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-configurable--page-root" class="md-nav__link">configurable</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Configurable</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.save_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_config()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -475,8 +533,49 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="configurable">
-<h1 id="api-graphnet-utilities-config-configurable--page-root">configurable<a class="headerlink" href="#api-graphnet-utilities-config-configurable--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.utilities.config.configurable">
+<span id="configurable"></span><h1 id="api-graphnet-utilities-config-configurable--page-root">configurable<a class="headerlink" href="#api-graphnet-utilities-config-configurable--page-root" title="Link to this heading">¶</a></h1>
+<p>Bases for all configurable classes in  <cite>graphnet</cite>.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.utilities.config.configurable.Configurable">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.configurable.</span></span><span class="sig-name descname"><span class="pre">Configurable</span></span><a class="reference internal" href="../_modules/graphnet/utilities/config/configurable.html#Configurable"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.configurable.Configurable" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code></p>
+<p>Base class for all configurable classes in graphnet.</p>
+<p>Construct <cite>Configurable</cite>.</p>
+<dl class="field-list simple">
+</dl>
+<dl class="py property">
+<dt class="sig sig-object py" id="graphnet.utilities.config.configurable.Configurable.config">
+<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><span class="pre">BaseConfig</span></a></em><a class="headerlink" href="#graphnet.utilities.config.configurable.Configurable.config" title="Link to this definition">¶</a></dt>
+<dd><p>Return configuration to re-create the instance.</p>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.utilities.config.configurable.Configurable.save_config">
+<span class="sig-name descname"><span class="pre">save_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/configurable.html#Configurable.save_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.configurable.Configurable.save_config" title="Link to this definition">¶</a></dt>
+<dd><p>Save Config to <cite>path</cite> as YAML file.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.utilities.config.configurable.Configurable.from_config">
+<em class="property"><span class="pre">abstract</span><span class="w"> </span><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">from_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">source</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/configurable.html#Configurable.from_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.configurable.Configurable.from_config" title="Link to this definition">¶</a></dt>
+<dd><p>Construct instance from <cite>source</cite> configuration.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>source</strong> (<a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><em>BaseConfig</em></a><em> | </em><em>str</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -526,7 +625,7 @@ <h1 id="api-graphnet-utilities-config-configurable--page-root">configurable<a cl
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.config.dataset_config.html b/api/graphnet.utilities.config.dataset_config.html
index 27d11fb12..d3b202a27 100644
--- a/api/graphnet.utilities.config.dataset_config.html
+++ b/api/graphnet.utilities.config.dataset_config.html
@@ -371,11 +371,198 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-dataset-config--page-root" class="md-nav__link">dataset_config</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DatasetConfig</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">pulsemaps</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.features" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">features</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.index_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">index_column</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth_table</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">string_selection</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">selection</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_table</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_column</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_default_value</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.seed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">seed</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">graph_definition</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.save_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_dataset_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DatasetConfig</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">pulsemaps</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.features" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">features</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.index_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">index_column</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth_table</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">string_selection</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">selection</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_table</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_column</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_default_value</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.seed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">seed</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">graph_definition</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.dataset_config.save_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_dataset_config()</span></code></a>
+      
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -465,7 +652,54 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-dataset-config--page-root" class="md-nav__link">dataset_config</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DatasetConfig</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">pulsemaps</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.features" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">features</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.index_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">index_column</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth_table</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">string_selection</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">selection</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_table</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_column</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_default_value</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.seed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">seed</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">graph_definition</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.save_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_dataset_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -475,8 +709,206 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="dataset-config">
-<h1 id="api-graphnet-utilities-config-dataset-config--page-root">dataset_config<a class="headerlink" href="#api-graphnet-utilities-config-dataset-config--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.utilities.config.dataset_config">
+<span id="dataset-config"></span><h1 id="api-graphnet-utilities-config-dataset-config--page-root">dataset_config<a class="headerlink" href="#api-graphnet-utilities-config-dataset-config--page-root" title="Link to this heading">¶</a></h1>
+<p>Config classes for the <cite>graphnet.data.dataset</cite> module.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.dataset_config.</span></span><span class="sig-name descname"><span class="pre">DatasetConfig</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/dataset_config.html#DatasetConfig"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></p>
+<p>Configuration for all <a href="#id1"><span class="problematic" id="id2">`</span></a>Dataset`s.</p>
+<p>Construct <cite>DataConfig</cite>.</p>
+<p>Can be used for dataset configuration as code, thereby making dataset
+construction more transparent and reproducible.</p>
+<p class="rubric">Examples</p>
+<p>In one session, do:</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="o">...</span><span class="p">)</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">dataset</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">()</span>
+<span class="go">path: (...)</span>
+<span class="go">pulsemaps:</span>
+<span class="go">    - (...)</span>
+<span class="go">(...)</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">dataset</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="s2">"dataset.yml"</span><span class="p">)</span>
+</pre></div>
+</div>
+<p>In another session, you can then do:
+&gt;&gt;&gt; dataset = Dataset.from_config(“dataset.yml”)</p>
+<p># Uniquely for <cite>DatasetConfig</cite>, you can also define and load
+# multiple datasets
+&gt;&gt;&gt; dataset.config.selection = {</p>
+<blockquote>
+<div><p>“train”: “event_no % 2 == 0”,
+“test”: “event_no % 2 == 1”,</p>
+</div></blockquote>
+<p>}
+&gt;&gt;&gt; dataset.config.dump(“dataset.yml”)
+&gt;&gt;&gt; datasets: Dict[str, Dataset] = Dataset.from_config(</p>
+<blockquote>
+<div><p>“dataset.yml”</p>
+</div></blockquote>
+<p>)
+&gt;&gt;&gt; datasets
+{</p>
+<blockquote>
+<div><p>“train”: Dataset(…),
+“test”: Dataset(…),</p>
+</div></blockquote>
+<p>}</p>
+<p># You can also combine multiple selections into a single, named
+# dataset
+&gt;&gt;&gt; dataset.config.selection = {</p>
+<blockquote>
+<div><dl class="simple">
+<dt>“train”: [</dt><dd><p>“event_no % 2 == 0 &amp; abs(pid) == 12”,
+“event_no % 2 == 0 &amp; abs(pid) == 14”,
+“event_no % 2 == 0 &amp; abs(pid) == 16”,</p>
+</dd>
+</dl>
+<p>],
+(…)</p>
+</div></blockquote>
+<p>}
+&gt;&gt;&gt; dataset.config.dump(“dataset.yml”)
+&gt;&gt;&gt; datasets: Dict[str, EnsembleDataset] = Dataset.from_config(</p>
+<blockquote>
+<div><p>“dataset.yml”</p>
+</div></blockquote>
+<p>)
+&gt;&gt;&gt; datasets
+{</p>
+<blockquote>
+<div><p>“train”: EnsembleDataset(…),
+(…)</p>
+</div></blockquote>
+<p>}</p>
+<p># Finally, you can still reference existing selection files in CSV
+# or JSON formats:
+&gt;&gt;&gt; dataset.config.selection = {</p>
+<blockquote>
+<div><p>“train”: “50000 random events ~ train_selection.csv”,
+“test”: “test_selection.csv”,</p>
+</div></blockquote>
+<p>}</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>path</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>pulsemaps</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>features</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>truth</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>node_truth</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>index_column</strong> (<em>str</em>) – </p></li>
+<li><p><strong>truth_table</strong> (<em>str</em>) – </p></li>
+<li><p><strong>node_truth_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>string_selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>selection</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>List</em><em>[</em><em>int</em><em> | </em><em>List</em><em>[</em><em>int</em><em>]</em><em>] </em><em>| </em><em>Dict</em><em>[</em><em>str</em><em>, </em><em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+<li><p><strong>loss_weight_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>loss_weight_column</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>loss_weight_default_value</strong> (<em>float</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>seed</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li>
+<li><p><strong>graph_definition</strong> (<em>Any</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.path">
+<span class="sig-name descname"><span class="pre">path</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.path" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps">
+<span class="sig-name descname"><span class="pre">pulsemaps</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.features">
+<span class="sig-name descname"><span class="pre">features</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.features" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.truth">
+<span class="sig-name descname"><span class="pre">truth</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.node_truth">
+<span class="sig-name descname"><span class="pre">node_truth</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.index_column">
+<span class="sig-name descname"><span class="pre">index_column</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.index_column" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.truth_table">
+<span class="sig-name descname"><span class="pre">truth_table</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table">
+<span class="sig-name descname"><span class="pre">node_truth_table</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.string_selection">
+<span class="sig-name descname"><span class="pre">string_selection</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.selection">
+<span class="sig-name descname"><span class="pre">selection</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">],</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code><span class="pre">]]],</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]],</span> <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.selection" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table">
+<span class="sig-name descname"><span class="pre">loss_weight_table</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column">
+<span class="sig-name descname"><span class="pre">loss_weight_column</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value">
+<span class="sig-name descname"><span class="pre">loss_weight_default_value</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.seed">
+<span class="sig-name descname"><span class="pre">seed</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.seed" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition">
+<span class="sig-name descname"><span class="pre">graph_definition</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.as_dict">
+<span class="sig-name descname"><span class="pre">as_dict</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/dataset_config.html#DatasetConfig.as_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict" title="Link to this definition">¶</a></dt>
+<dd><p>Represent ModelConfig as a dict.</p>
+<p>This builds on <cite>BaseModel.dict()</cite> but wraps the output in a single-key
+dictionary to make it unambiguous to identify model arguments that are
+themselves models.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.model_config">
+<span class="sig-name descname"><span class="pre">model_config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[ConfigDict]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_config" title="Link to this definition">¶</a></dt>
+<dd><p>Configuration for the model, should be a dictionary conforming to [<cite>ConfigDict</cite>][pydantic.config.ConfigDict].</p>
+</dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.model_fields">
+<span class="sig-name descname"><span class="pre">model_fields</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[dict[str,</span> <span class="pre">FieldInfo]]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{'features':</span> <span class="pre">FieldInfo(annotation=List[str],</span> <span class="pre">required=True),</span> <span class="pre">'graph_definition':</span> <span class="pre">FieldInfo(annotation=Any,</span> <span class="pre">required=False),</span> <span class="pre">'index_column':</span> <span class="pre">FieldInfo(annotation=str,</span> <span class="pre">required=False,</span> <span class="pre">default='event_no'),</span> <span class="pre">'loss_weight_column':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'loss_weight_default_value':</span> <span class="pre">FieldInfo(annotation=Union[float,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'loss_weight_table':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'node_truth':</span> <span class="pre">FieldInfo(annotation=Union[List[str],</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'node_truth_table':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'path':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">List[str]],</span> <span class="pre">required=True),</span> <span class="pre">'pulsemaps':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">List[str]],</span> <span class="pre">required=True),</span> <span class="pre">'seed':</span> <span class="pre">FieldInfo(annotation=Union[int,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'selection':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">List[str],</span> <span class="pre">List[Union[int,</span> <span class="pre">List[int]]],</span> <span class="pre">Dict[str,</span> <span class="pre">Union[str,</span> <span class="pre">List[str]]],</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'string_selection':</span> <span class="pre">FieldInfo(annotation=Union[List[int],</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'truth':</span> <span class="pre">FieldInfo(annotation=List[str],</span> <span class="pre">required=True),</span> <span class="pre">'truth_table':</span> <span class="pre">FieldInfo(annotation=str,</span> <span class="pre">required=False,</span> <span class="pre">default='truth')}</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields" title="Link to this definition">¶</a></dt>
+<dd><p>Metadata about the fields defined on the model,
+mapping of field names to [<cite>FieldInfo</cite>][pydantic.fields.FieldInfo].</p>
+<p>This replaces <cite>Model.__fields__</cite> from Pydantic V1.</p>
+</dd></dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.save_dataset_config">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.dataset_config.</span></span><span class="sig-name descname"><span class="pre">save_dataset_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">init_fn</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/dataset_config.html#save_dataset_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.dataset_config.save_dataset_config" title="Link to this definition">¶</a></dt>
+<dd><p>Save the arguments to <cite>__init__</cite> functions as member <cite>DatasetConfig</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>init_fn</strong> (<em>Callable</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -526,7 +958,7 @@ <h1 id="api-graphnet-utilities-config-dataset-config--page-root">dataset_config<
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.config.html b/api/graphnet.utilities.config.html
index bba4a6957..59a759d23 100644
--- a/api/graphnet.utilities.config.html
+++ b/api/graphnet.utilities.config.html
@@ -474,17 +474,44 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="config">
-<h1 id="api-graphnet-utilities-config--page-root">config<a class="headerlink" href="#api-graphnet-utilities-config--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.utilities.config">
+<span id="config"></span><h1 id="api-graphnet-utilities-config--page-root">config<a class="headerlink" href="#api-graphnet-utilities-config--page-root" title="Link to this heading">¶</a></h1>
+<p>Modules for configuration files for use across <cite>graphnet</cite>.</p>
 <p><h2> Submodules </h2></p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.base_config.html">base_config</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.configurable.html">configurable</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.dataset_config.html">dataset_config</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.model_config.html">model_config</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.parsing.html">parsing</a></li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.training_config.html">training_config</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.base_config.html">base_config</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig"><code class="docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.get_all_argument_values"><code class="docutils literal notranslate"><span class="pre">get_all_argument_values()</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.configurable.html">configurable</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable"><code class="docutils literal notranslate"><span class="pre">Configurable</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.dataset_config.html">dataset_config</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig"><code class="docutils literal notranslate"><span class="pre">DatasetConfig</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.save_dataset_config"><code class="docutils literal notranslate"><span class="pre">save_dataset_config()</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.model_config.html">model_config</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig"><code class="docutils literal notranslate"><span class="pre">ModelConfig</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.save_model_config"><code class="docutils literal notranslate"><span class="pre">save_model_config()</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.parsing.html">parsing</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.traverse_and_apply"><code class="docutils literal notranslate"><span class="pre">traverse_and_apply()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.list_all_submodules"><code class="docutils literal notranslate"><span class="pre">list_all_submodules()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_all_grapnet_classes"><code class="docutils literal notranslate"><span class="pre">get_all_grapnet_classes()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_module"><code class="docutils literal notranslate"><span class="pre">is_graphnet_module()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_class"><code class="docutils literal notranslate"><span class="pre">is_graphnet_class()</span></code></a></li>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_graphnet_classes"><code class="docutils literal notranslate"><span class="pre">get_graphnet_classes()</span></code></a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.training_config.html">training_config</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig"><code class="docutils literal notranslate"><span class="pre">TrainingConfig</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -536,7 +563,7 @@ <h1 id="api-graphnet-utilities-config--page-root">config<a class="headerlink" hr
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.config.model_config.html b/api/graphnet.utilities.config.model_config.html
index 5c6fa0fbb..52531fc50 100644
--- a/api/graphnet.utilities.config.model_config.html
+++ b/api/graphnet.utilities.config.model_config.html
@@ -378,11 +378,81 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-model-config--page-root" class="md-nav__link">model_config</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ModelConfig</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.class_name" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">class_name</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.arguments" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">arguments</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.save_model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_model_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.model_config.ModelConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ModelConfig</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.model_config.ModelConfig.class_name" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">class_name</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.model_config.ModelConfig.arguments" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">arguments</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.model_config.ModelConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.model_config.ModelConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.model_config.ModelConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.model_config.save_model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_model_config()</span></code></a>
+      
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -465,7 +535,28 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-model-config--page-root" class="md-nav__link">model_config</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ModelConfig</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.class_name" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">class_name</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.arguments" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">arguments</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.save_model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_model_config()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -475,8 +566,88 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="model-config">
-<h1 id="api-graphnet-utilities-config-model-config--page-root">model_config<a class="headerlink" href="#api-graphnet-utilities-config-model-config--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.utilities.config.model_config">
+<span id="model-config"></span><h1 id="api-graphnet-utilities-config-model-config--page-root">model_config<a class="headerlink" href="#api-graphnet-utilities-config-model-config--page-root" title="Link to this heading">¶</a></h1>
+<p>Config classes for the <cite>graphnet.models</cite> module.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.model_config.</span></span><span class="sig-name descname"><span class="pre">ModelConfig</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">class_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">arguments</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/model_config.html#ModelConfig"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></p>
+<p>Configuration for all <a href="#id1"><span class="problematic" id="id2">`</span></a>Model`s.</p>
+<p>Construct <cite>ModelConfig</cite>.</p>
+<p>Can be used for model configuration as code, thereby making model
+construction more transparent and reproducible. Note that this does
+<em>not</em> save any trainable weights, meaning this is only a configuration
+for the model’s hyperparameters. Any model instantiated from a
+ModelConfig or file will be randomly initialised, and thus should be
+trained.</p>
+<p class="rubric">Examples</p>
+<p>In one session, do:</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">(</span><span class="o">...</span><span class="p">)</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">()</span>
+<span class="go">arguments:</span>
+<span class="go">    - (...): (...)</span>
+<span class="go">class_name: Model</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="s2">"model.yml"</span><span class="p">)</span>
+</pre></div>
+</div>
+<p>In another session, you can then do:
+&gt;&gt;&gt; model = Model.from_config(“model.yml”)</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>class_name</strong> (<em>str</em>) – </p></li>
+<li><p><strong>arguments</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Any</em><em>]</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.class_name">
+<span class="sig-name descname"><span class="pre">class_name</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code></em><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.class_name" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.arguments">
+<span class="sig-name descname"><span class="pre">arguments</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.arguments" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py method">
+<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.as_dict">
+<span class="sig-name descname"><span class="pre">as_dict</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/model_config.html#ModelConfig.as_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.as_dict" title="Link to this definition">¶</a></dt>
+<dd><p>Represent ModelConfig as a dict.</p>
+<p>This builds on <cite>BaseModel.dict()</cite> but wraps the output in a single-key
+dictionary to make it unambiguous to identify model arguments that are
+themselves models.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]</p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.model_config">
+<span class="sig-name descname"><span class="pre">model_config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[ConfigDict]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.model_config" title="Link to this definition">¶</a></dt>
+<dd><p>Configuration for the model, should be a dictionary conforming to [<cite>ConfigDict</cite>][pydantic.config.ConfigDict].</p>
+</dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.model_fields">
+<span class="sig-name descname"><span class="pre">model_fields</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[dict[str,</span> <span class="pre">FieldInfo]]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{'arguments':</span> <span class="pre">FieldInfo(annotation=Dict[str,</span> <span class="pre">Any],</span> <span class="pre">required=True),</span> <span class="pre">'class_name':</span> <span class="pre">FieldInfo(annotation=str,</span> <span class="pre">required=True)}</span></em><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.model_fields" title="Link to this definition">¶</a></dt>
+<dd><p>Metadata about the fields defined on the model,
+mapping of field names to [<cite>FieldInfo</cite>][pydantic.fields.FieldInfo].</p>
+<p>This replaces <cite>Model.__fields__</cite> from Pydantic V1.</p>
+</dd></dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.save_model_config">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.model_config.</span></span><span class="sig-name descname"><span class="pre">save_model_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">init_fn</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/model_config.html#save_model_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.model_config.save_model_config" title="Link to this definition">¶</a></dt>
+<dd><p>Save the arguments to <cite>__init__</cite> functions as a member <cite>ModelConfig</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>init_fn</strong> (<em>Callable</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -526,7 +697,7 @@ <h1 id="api-graphnet-utilities-config-model-config--page-root">model_config<a cl
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.config.parsing.html b/api/graphnet.utilities.config.parsing.html
index 210d5f809..0e932473f 100644
--- a/api/graphnet.utilities.config.parsing.html
+++ b/api/graphnet.utilities.config.parsing.html
@@ -385,11 +385,70 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-parsing--page-root" class="md-nav__link">parsing</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.traverse_and_apply" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">traverse_and_apply()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.list_all_submodules" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">list_all_submodules()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.get_all_grapnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_grapnet_classes()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.is_graphnet_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_module()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.is_graphnet_class" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_class()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.get_graphnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_graphnet_classes()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.parsing.traverse_and_apply" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">traverse_and_apply()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.parsing.list_all_submodules" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">list_all_submodules()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.parsing.get_all_grapnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_grapnet_classes()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.parsing.is_graphnet_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_module()</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.parsing.is_graphnet_class" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_class()</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.parsing.get_graphnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_graphnet_classes()</span></code></a>
+      
+    
+    </li></ul>
+    
     </li>
     <li class="md-nav__item">
     
@@ -465,7 +524,24 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-parsing--page-root" class="md-nav__link">parsing</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.traverse_and_apply" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">traverse_and_apply()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.list_all_submodules" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">list_all_submodules()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.get_all_grapnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_grapnet_classes()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.is_graphnet_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_module()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.is_graphnet_class" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_class()</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.get_graphnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_graphnet_classes()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -475,8 +551,91 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="parsing">
-<h1 id="api-graphnet-utilities-config-parsing--page-root">parsing<a class="headerlink" href="#api-graphnet-utilities-config-parsing--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.utilities.config.parsing">
+<span id="parsing"></span><h1 id="api-graphnet-utilities-config-parsing--page-root">parsing<a class="headerlink" href="#api-graphnet-utilities-config-parsing--page-root" title="Link to this heading">¶</a></h1>
+<p>Utility functions for parsing for using with Config-classes.</p>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.traverse_and_apply">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">traverse_and_apply</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fn_kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#traverse_and_apply"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.traverse_and_apply" title="Link to this definition">¶</a></dt>
+<dd><p>Apply <cite>fn</cite> to all elements in <cite>obj</cite>, resulting in same structure.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><ul class="simple">
+<li><p><strong>obj</strong> (<em>Any</em>) – </p></li>
+<li><p><strong>fn</strong> (<em>Callable</em>) – </p></li>
+<li><p><strong>fn_kwargs</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Any</em><em>] </em><em>| </em><em>None</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.list_all_submodules">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">list_all_submodules</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">packages</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#list_all_submodules"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.list_all_submodules" title="Link to this definition">¶</a></dt>
+<dd><p>List all submodules in <cite>packages</cite> recursively.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">ModuleType</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>packages</strong> (<em>module</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.get_all_grapnet_classes">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">get_all_grapnet_classes</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">packages</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#get_all_grapnet_classes"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.get_all_grapnet_classes" title="Link to this definition">¶</a></dt>
+<dd><p>List all grapnet classes in <cite>packages</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">type</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>packages</strong> (<em>module</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.is_graphnet_module">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">is_graphnet_module</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#is_graphnet_module"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.is_graphnet_module" title="Link to this definition">¶</a></dt>
+<dd><p>Return whether <cite>obj</cite> is a module in graphnet.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>obj</strong> (<em>module</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.is_graphnet_class">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">is_graphnet_class</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#is_graphnet_class"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.is_graphnet_class" title="Link to this definition">¶</a></dt>
+<dd><p>Return whether <cite>obj</cite> is a class in graphnet.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>obj</strong> (<em>type</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.get_graphnet_classes">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">get_graphnet_classes</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">module</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#get_graphnet_classes"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.get_graphnet_classes" title="Link to this definition">¶</a></dt>
+<dd><p>Return a lookup of all graphnet class names in <cite>module</cite>.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">type</span></code>]</p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>module</strong> (<em>module</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -526,7 +685,7 @@ <h1 id="api-graphnet-utilities-config-parsing--page-root">parsing<a class="heade
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.config.training_config.html b/api/graphnet.utilities.config.training_config.html
index bdb49dee2..26774e9f2 100644
--- a/api/graphnet.utilities.config.training_config.html
+++ b/api/graphnet.utilities.config.training_config.html
@@ -392,11 +392,81 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-training-config--page-root" class="md-nav__link">training_config</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TrainingConfig</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.target" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">early_stopping_patience</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dataloader</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.training_config.TrainingConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TrainingConfig</span></code></a>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.training_config.TrainingConfig.target" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">early_stopping_patience</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.training_config.TrainingConfig.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.training_config.TrainingConfig.dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dataloader</span></code></a>
+      
+    
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.training_config.TrainingConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
       
     
+    </li>
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.config.training_config.TrainingConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+      
+    
+    </li></ul>
+    
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -465,7 +535,28 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-config-training-config--page-root" class="md-nav__link">training_config</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TrainingConfig</span></code></a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.target" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">early_stopping_patience</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dataloader</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a>
+        </li>
+        <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a>
+        </li></ul>
+            </nav>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -475,8 +566,58 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="training-config">
-<h1 id="api-graphnet-utilities-config-training-config--page-root">training_config<a class="headerlink" href="#api-graphnet-utilities-config-training-config--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.utilities.config.training_config">
+<span id="training-config"></span><h1 id="api-graphnet-utilities-config-training-config--page-root">training_config<a class="headerlink" href="#api-graphnet-utilities-config-training-config--page-root" title="Link to this heading">¶</a></h1>
+<p>Config classes for the <cite>graphnet.training</cite> module.</p>
+<dl class="py class">
+<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig">
+<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.training_config.</span></span><span class="sig-name descname"><span class="pre">TrainingConfig</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">early_stopping_patience</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fit</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/training_config.html#TrainingConfig"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig" title="Link to this definition">¶</a></dt>
+<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></p>
+<p>Configuration for all trainings.</p>
+<p>Create a new model by parsing and validating input data from keyword arguments.</p>
+<p>Raises [<cite>ValidationError</cite>][pydantic_core.ValidationError] if the input data cannot be
+validated to form a valid model.</p>
+<p><cite>__init__</cite> uses <cite>__pydantic_self__</cite> instead of the more common <cite>self</cite> for the first arg to
+allow <cite>self</cite> as a field name.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>target</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li>
+<li><p><strong>early_stopping_patience</strong> (<em>int</em>) – </p></li>
+<li><p><strong>fit</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Any</em><em>]</em>) – </p></li>
+<li><p><strong>dataloader</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Any</em><em>]</em>) – </p></li>
+</ul>
+</dd>
+</dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.target">
+<span class="sig-name descname"><span class="pre">target</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.target" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience">
+<span class="sig-name descname"><span class="pre">early_stopping_patience</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.fit">
+<span class="sig-name descname"><span class="pre">fit</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.fit" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.dataloader">
+<span class="sig-name descname"><span class="pre">dataloader</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.dataloader" title="Link to this definition">¶</a></dt>
+<dd></dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.model_config">
+<span class="sig-name descname"><span class="pre">model_config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[ConfigDict]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.model_config" title="Link to this definition">¶</a></dt>
+<dd><p>Configuration for the model, should be a dictionary conforming to [<cite>ConfigDict</cite>][pydantic.config.ConfigDict].</p>
+</dd></dl>
+<dl class="py attribute">
+<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.model_fields">
+<span class="sig-name descname"><span class="pre">model_fields</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[dict[str,</span> <span class="pre">FieldInfo]]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{'dataloader':</span> <span class="pre">FieldInfo(annotation=Dict[str,</span> <span class="pre">Any],</span> <span class="pre">required=True),</span> <span class="pre">'early_stopping_patience':</span> <span class="pre">FieldInfo(annotation=int,</span> <span class="pre">required=True),</span> <span class="pre">'fit':</span> <span class="pre">FieldInfo(annotation=Dict[str,</span> <span class="pre">Any],</span> <span class="pre">required=True),</span> <span class="pre">'target':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">List[str]],</span> <span class="pre">required=True)}</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.model_fields" title="Link to this definition">¶</a></dt>
+<dd><p>Metadata about the fields defined on the model,
+mapping of field names to [<cite>FieldInfo</cite>][pydantic.fields.FieldInfo].</p>
+<p>This replaces <cite>Model.__fields__</cite> from Pydantic V1.</p>
+</dd></dl>
+</dd></dl>
 </section>
 
 
@@ -526,7 +667,7 @@ <h1 id="api-graphnet-utilities-config-training-config--page-root">training_confi
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.decorators.html b/api/graphnet.utilities.decorators.html
index 0203d3c40..92ae0d5b7 100644
--- a/api/graphnet.utilities.decorators.html
+++ b/api/graphnet.utilities.decorators.html
@@ -484,7 +484,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.filesys.html b/api/graphnet.utilities.filesys.html
index 8e0c99feb..f13e120c5 100644
--- a/api/graphnet.utilities.filesys.html
+++ b/api/graphnet.utilities.filesys.html
@@ -604,7 +604,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.html b/api/graphnet.utilities.html
index 847defb5c..5723c23e5 100644
--- a/api/graphnet.utilities.html
+++ b/api/graphnet.utilities.html
@@ -476,7 +476,10 @@
 <li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger"><code class="docutils literal notranslate"><span class="pre">Logger</span></code></a></li>
 </ul>
 </li>
-<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.maths.html">maths</a></li>
+<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.maths.html">maths</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.maths.html#graphnet.utilities.maths.eps_like"><code class="docutils literal notranslate"><span class="pre">eps_like()</span></code></a></li>
+</ul>
+</li>
 </ul>
 </div>
 </section>
@@ -528,7 +531,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.imports.html b/api/graphnet.utilities.imports.html
index 7166dfe02..7ec9aa916 100644
--- a/api/graphnet.utilities.imports.html
+++ b/api/graphnet.utilities.imports.html
@@ -581,7 +581,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.logging.html b/api/graphnet.utilities.logging.html
index a16ad7888..7972b4bad 100644
--- a/api/graphnet.utilities.logging.html
+++ b/api/graphnet.utilities.logging.html
@@ -831,7 +831,7 @@
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/graphnet.utilities.maths.html b/api/graphnet.utilities.maths.html
index 9fc848fe1..8eb8048fa 100644
--- a/api/graphnet.utilities.maths.html
+++ b/api/graphnet.utilities.maths.html
@@ -393,11 +393,25 @@
       
         
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-maths--page-root" class="md-nav__link">maths</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.maths.eps_like" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">eps_like()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
+      <ul class="md-nav__list"> 
+    <li class="md-nav__item">
+    
+    
+      <a href="#graphnet.utilities.maths.eps_like" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">eps_like()</span></code></a>
       
     
+    </li></ul>
+    
     </li></ul>
     
     </li>
@@ -422,7 +436,14 @@
               <div class="md-sidebar__inner">
                 
 <nav class="md-nav md-nav--secondary">
+    <label class="md-nav__title" for="__toc">"Contents"</label>
   <ul class="md-nav__list" data-md-scrollfix="">
+        <li class="md-nav__item"><a href="#api-graphnet-utilities-maths--page-root" class="md-nav__link">maths</a><nav class="md-nav">
+              <ul class="md-nav__list">
+        <li class="md-nav__item"><a href="#graphnet.utilities.maths.eps_like" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">eps_like()</span></code></a>
+        </li></ul>
+            </nav>
+        </li>
   </ul>
 </nav>
               </div>
@@ -432,8 +453,22 @@
         <div class="md-content">
           <article class="md-content__inner md-typeset" role="main">
             
-  <section id="maths">
-<h1 id="api-graphnet-utilities-maths--page-root">maths<a class="headerlink" href="#api-graphnet-utilities-maths--page-root" title="Link to this heading">¶</a></h1>
+  <section id="module-graphnet.utilities.maths">
+<span id="maths"></span><h1 id="api-graphnet-utilities-maths--page-root">maths<a class="headerlink" href="#api-graphnet-utilities-maths--page-root" title="Link to this heading">¶</a></h1>
+<p>Collection of assorted “maths-like” functions.</p>
+<dl class="py function">
+<dt class="sig sig-object py" id="graphnet.utilities.maths.eps_like">
+<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.maths.</span></span><span class="sig-name descname"><span class="pre">eps_like</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensor</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/maths.html#eps_like"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.maths.eps_like" title="Link to this definition">¶</a></dt>
+<dd><p>Return <cite>eps</cite> matching <cite>tensor</cite>’s dtype.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Return type<span class="colon">:</span></dt>
+<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p>
+</dd>
+<dt class="field-even">Parameters<span class="colon">:</span></dt>
+<dd class="field-even"><p><strong>tensor</strong> (<em>Tensor</em>) – </p>
+</dd>
+</dl>
+</dd></dl>
 </section>
 
 
@@ -483,7 +518,7 @@ <h1 id="api-graphnet-utilities-maths--page-root">maths<a class="headerlink" href
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/api/modules.html b/api/modules.html
index 6f49ef2ab..0bc19ad98 100644
--- a/api/modules.html
+++ b/api/modules.html
@@ -362,7 +362,7 @@ <h1 id="api-modules--page-root">src<a class="headerlink" href="#api-modules--pag
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/contribute.html b/contribute.html
index 0086e93f9..1a6ccae8d 100644
--- a/contribute.html
+++ b/contribute.html
@@ -484,7 +484,7 @@ <h2 id="code-quality">Code quality<a class="headerlink" href="#code-quality" tit
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/genindex.html b/genindex.html
index 360c7a1ef..7ef7dc30a 100644
--- a/genindex.html
+++ b/genindex.html
@@ -339,23 +339,44 @@ <h1 id="index">Index</h1>
  | <a href="#N"><strong>N</strong></a>
  | <a href="#O"><strong>O</strong></a>
  | <a href="#P"><strong>P</strong></a>
+ | <a href="#Q"><strong>Q</strong></a>
  | <a href="#R"><strong>R</strong></a>
  | <a href="#S"><strong>S</strong></a>
  | <a href="#T"><strong>T</strong></a>
  | <a href="#U"><strong>U</strong></a>
+ | <a href="#V"><strong>V</strong></a>
  | <a href="#W"><strong>W</strong></a>
+ | <a href="#Z"><strong>Z</strong></a>
  
 </div>
 <h2 id="A">A</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.add_label">add_label() (graphnet.data.dataset.dataset.Dataset method)</a>
+</li>
       <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.any_pulsemap_is_non_empty">any_pulsemap_is_non_empty() (graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter method)</a>
 </li>
-  </ul></td>
-  <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.ArgumentParser">ArgumentParser (class in graphnet.utilities.argparse)</a>
 </li>
+      <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.arguments">arguments (graphnet.utilities.config.model_config.ModelConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.as_dict">as_dict() (graphnet.utilities.config.base_config.BaseConfig method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict">(graphnet.utilities.config.dataset_config.DatasetConfig method)</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.as_dict">(graphnet.utilities.config.model_config.ModelConfig method)</a>
+</li>
+      </ul></li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.data.sqlite.sqlite_utilities.html#graphnet.data.sqlite.sqlite_utilities.attach_index">attach_index() (in module graphnet.data.sqlite.sqlite_utilities)</a>
+</li>
+      <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.AttributeCoarsening">AttributeCoarsening (class in graphnet.models.coarsening)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction">AzimuthReconstruction (class in graphnet.models.task.reconstruction)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa">AzimuthReconstructionWithKappa (class in graphnet.models.task.reconstruction)</a>
 </li>
   </ul></td>
 </tr></table>
@@ -363,10 +384,20 @@ <h2 id="A">A</h2>
 <h2 id="B">B</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
-      <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.BjoernLow">BjoernLow (class in graphnet.training.weight_fitting)</a>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK.backward">backward() (graphnet.training.loss_functions.LogCMK static method)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig">BaseConfig (class in graphnet.utilities.config.base_config)</a>
+</li>
+      <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask">BinaryClassificationTask (class in graphnet.models.task.classification)</a>
 </li>
   </ul></td>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits">BinaryClassificationTaskLogits (class in graphnet.models.task.classification)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.BinaryCrossEntropyLoss">BinaryCrossEntropyLoss (class in graphnet.training.loss_functions)</a>
+</li>
+      <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.BjoernLow">BjoernLow (class in graphnet.training.weight_fitting)</a>
+</li>
       <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.break_cyclic_recursion">break_cyclic_recursion() (in module graphnet.data.extractors.utilities.types)</a>
 </li>
   </ul></td>
@@ -376,26 +407,62 @@ <h2 id="C">C</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.cache_output_files">cache_output_files() (in module graphnet.data.dataconverter)</a>
+</li>
+      <li><a href="api/graphnet.models.utils.html#graphnet.models.utils.calculate_distance_matrix">calculate_distance_matrix() (in module graphnet.models.utils)</a>
+</li>
+      <li><a href="api/graphnet.models.utils.html#graphnet.models.utils.calculate_xyzt_homophily">calculate_xyzt_homophily() (in module graphnet.models.utils)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.cast_object_to_pure_python">cast_object_to_pure_python() (in module graphnet.data.extractors.utilities.types)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.cast_pulse_series_to_pure_python">cast_pulse_series_to_pure_python() (in module graphnet.data.extractors.utilities.types)</a>
 </li>
-      <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.config_updater">config_updater() (in module graphnet.pisa.fitting)</a>
+      <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.class_name">class_name (graphnet.utilities.config.model_config.ModelConfig attribute)</a>
 </li>
-      <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe">construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)</a>
+      <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening">Coarsening (class in graphnet.models.coarsening)</a>
+</li>
+      <li><a href="api/graphnet.data.dataloader.html#graphnet.data.dataloader.collate_fn">collate_fn() (in module graphnet.data.dataloader)</a>
+
+      <ul>
+        <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.collate_fn">(in module graphnet.training.utils)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.ColumnMissingException">ColumnMissingException</a>
+</li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.compute_loss">compute_loss() (graphnet.models.standard_model.StandardModel method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.compute_loss">(graphnet.models.task.task.Task method)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.concatenate">concatenate() (graphnet.data.dataset.dataset.Dataset class method)</a>
 </li>
   </ul></td>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.config">config (graphnet.utilities.config.configurable.Configurable property)</a>
+</li>
+      <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.config_updater">config_updater() (in module graphnet.pisa.fitting)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable">Configurable (class in graphnet.utilities.config.configurable)</a>
+</li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.configure_optimizers">configure_optimizers() (graphnet.models.standard_model.StandardModel method)</a>
+</li>
+      <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe">construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)</a>
+</li>
       <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.Options.contains">contains() (graphnet.utilities.argparse.Options method)</a>
 </li>
       <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.ContourFitter">ContourFitter (class in graphnet.pisa.fitting)</a>
+</li>
+      <li><a href="api/graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet">ConvNet (class in graphnet.models.gnn.convnet)</a>
 </li>
       <li><a href="api/graphnet.data.sqlite.sqlite_utilities.html#graphnet.data.sqlite.sqlite_utilities.create_table">create_table() (in module graphnet.data.sqlite.sqlite_utilities)</a>
 </li>
       <li><a href="api/graphnet.data.sqlite.sqlite_utilities.html#graphnet.data.sqlite.sqlite_utilities.create_table_and_save_to_sql">create_table_and_save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)</a>
 </li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.critical">critical() (graphnet.utilities.logging.Logger method)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.CrossEntropyLoss">CrossEntropyLoss (class in graphnet.training.loss_functions)</a>
+</li>
+      <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.CustomDOMCoarsening">CustomDOMCoarsening (class in graphnet.models.coarsening)</a>
 </li>
   </ul></td>
 </tr></table>
@@ -409,8 +476,14 @@ <h2 id="D">D</h2>
 </li>
       <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter">DataConverter (class in graphnet.data.dataconverter)</a>
 </li>
-  </ul></td>
-  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader">DataLoader (class in graphnet.data.dataloader)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.dataloader">dataloader (graphnet.utilities.config.training_config.TrainingConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset">Dataset (class in graphnet.data.dataset.dataset)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig">DatasetConfig (class in graphnet.utilities.config.dataset_config)</a>
+</li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.debug">debug() (graphnet.utilities.logging.Logger method)</a>
 </li>
       <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.FEATURES.DEEPCORE">DEEPCORE (graphnet.data.constants.FEATURES attribute)</a>
@@ -419,16 +492,130 @@ <h2 id="D">D</h2>
         <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH.DEEPCORE">(graphnet.data.constants.TRUTH attribute)</a>
 </li>
       </ul></li>
+      <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels">default_prediction_labels (graphnet.models.task.classification.BinaryClassificationTask attribute)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels">(graphnet.models.task.classification.BinaryClassificationTaskLogits attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.AzimuthReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels">(graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels">(graphnet.models.task.reconstruction.DirectionReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.EnergyReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels">(graphnet.models.task.reconstruction.EnergyReconstructionWithPower attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels">(graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.InelasticityReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.PositionReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.TimeReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.VertexReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.ZenithReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels">(graphnet.models.task.reconstruction.ZenithReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask.default_prediction_labels">(graphnet.models.task.task.IdentityTask property)</a>
+</li>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.default_prediction_labels">(graphnet.models.task.task.Task property)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels">default_target_labels (graphnet.models.task.classification.BinaryClassificationTask attribute)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels">(graphnet.models.task.classification.BinaryClassificationTaskLogits attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels">(graphnet.models.task.reconstruction.AzimuthReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels">(graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels">(graphnet.models.task.reconstruction.DirectionReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels">(graphnet.models.task.reconstruction.EnergyReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels">(graphnet.models.task.reconstruction.EnergyReconstructionWithPower attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels">(graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels">(graphnet.models.task.reconstruction.InelasticityReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels">(graphnet.models.task.reconstruction.PositionReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels">(graphnet.models.task.reconstruction.TimeReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels">(graphnet.models.task.reconstruction.VertexReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels">(graphnet.models.task.reconstruction.ZenithReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels">(graphnet.models.task.reconstruction.ZenithReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask.default_target_labels">(graphnet.models.task.task.IdentityTask property)</a>
+</li>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.default_target_labels">(graphnet.models.task.task.Task property)</a>
+</li>
+      </ul></li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector">Detector (class in graphnet.models.detector.detector)</a>
+</li>
+      <li><a href="api/graphnet.training.labels.html#graphnet.training.labels.Direction">Direction (class in graphnet.training.labels)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa">DirectionReconstructionWithKappa (class in graphnet.models.task.reconstruction)</a>
+</li>
+      <li><a href="api/graphnet.data.dataloader.html#graphnet.data.dataloader.do_shuffle">do_shuffle() (in module graphnet.data.dataloader)</a>
+</li>
+      <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.DOMAndTimeWindowCoarsening">DOMAndTimeWindowCoarsening (class in graphnet.models.coarsening)</a>
+</li>
+      <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.DOMCoarsening">DOMCoarsening (class in graphnet.models.coarsening)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.dump">dump() (graphnet.utilities.config.base_config.BaseConfig method)</a>
+</li>
+      <li><a href="api/graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge">DynEdge (class in graphnet.models.gnn.dynedge)</a>
+</li>
+      <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv">DynEdgeConv (class in graphnet.models.components.layers)</a>
+</li>
+      <li><a href="api/graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST">DynEdgeJINST (class in graphnet.models.gnn.dynedge_jinst)</a>
+</li>
+      <li><a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO">DynEdgeTITO (class in graphnet.models.gnn.dynedge_kaggle_tito)</a>
+</li>
+      <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans">DynTrans (class in graphnet.models.components.layers)</a>
+</li>
   </ul></td>
 </tr></table>
 
 <h2 id="E">E</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
-      <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.error">error() (graphnet.utilities.logging.Logger method)</a>
+      <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience">early_stopping_patience (graphnet.utilities.config.training_config.TrainingConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito">EdgeConvTito (class in graphnet.models.components.layers)</a>
+</li>
+      <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition">EdgeDefinition (class in graphnet.models.graphs.edges.edges)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction">EnergyReconstruction (class in graphnet.models.task.reconstruction)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower">EnergyReconstructionWithPower (class in graphnet.models.task.reconstruction)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty">EnergyReconstructionWithUncertainty (class in graphnet.models.task.reconstruction)</a>
 </li>
   </ul></td>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.EnsembleDataset">EnsembleDataset (class in graphnet.data.dataset.dataset)</a>
+</li>
+      <li><a href="api/graphnet.utilities.maths.html#graphnet.utilities.maths.eps_like">eps_like() (in module graphnet.utilities.maths)</a>
+</li>
+      <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.error">error() (graphnet.utilities.logging.Logger method)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.EuclideanDistanceLoss">EuclideanDistanceLoss (class in graphnet.training.loss_functions)</a>
+</li>
+      <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EuclideanEdges">EuclideanEdges (class in graphnet.models.graphs.edges.edges)</a>
+</li>
       <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter.execute">execute() (graphnet.data.dataconverter.DataConverter method)</a>
 </li>
   </ul></td>
@@ -437,7 +624,23 @@ <h2 id="E">E</h2>
 <h2 id="F">F</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector.feature_map">feature_map() (graphnet.models.detector.detector.Detector method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86.feature_map">(graphnet.models.detector.icecube.IceCube86 method)</a>
+</li>
+        <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map">(graphnet.models.detector.icecube.IceCubeDeepCore method)</a>
+</li>
+        <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle.feature_map">(graphnet.models.detector.icecube.IceCubeKaggle method)</a>
+</li>
+        <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map">(graphnet.models.detector.icecube.IceCubeUpgrade method)</a>
+</li>
+        <li><a href="api/graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus.feature_map">(graphnet.models.detector.prometheus.Prometheus method)</a>
+</li>
+      </ul></li>
       <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.FEATURES">FEATURES (class in graphnet.data.constants)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.features">features (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
 </li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.file_handlers">file_handlers (graphnet.utilities.logging.Logger property)</a>
 </li>
@@ -453,12 +656,16 @@ <h2 id="F">F</h2>
 </li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.RepeatFilter.filter">filter() (graphnet.utilities.logging.RepeatFilter method)</a>
 </li>
-  </ul></td>
-  <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.find_i3_files">find_i3_files() (in module graphnet.utilities.filesys)</a>
 </li>
-      <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.WeightFitter.fit">fit() (graphnet.training.weight_fitting.WeightFitter method)</a>
+      <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.fit">fit (graphnet.utilities.config.training_config.TrainingConfig attribute)</a>
 </li>
+      <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.fit">fit() (graphnet.models.model.Model method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.WeightFitter.fit">(graphnet.training.weight_fitting.WeightFitter method)</a>
+</li>
+      </ul></li>
       <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.ContourFitter.fit_1d_contour">fit_1d_contour() (graphnet.pisa.fitting.ContourFitter method)</a>
 </li>
       <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.ContourFitter.fit_2d_contour">fit_2d_contour() (graphnet.pisa.fitting.ContourFitter method)</a>
@@ -467,9 +674,59 @@ <h2 id="F">F</h2>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.collections.html#graphnet.data.extractors.utilities.collections.flatten_nested_dictionary">flatten_nested_dictionary() (in module graphnet.data.extractors.utilities.collections)</a>
 </li>
+      <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening.forward">forward() (graphnet.models.coarsening.Coarsening method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv.forward">(graphnet.models.components.layers.DynEdgeConv method)</a>
+</li>
+        <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans.forward">(graphnet.models.components.layers.DynTrans method)</a>
+</li>
+        <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.forward">(graphnet.models.components.layers.EdgeConvTito method)</a>
+</li>
+        <li><a href="api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector.forward">(graphnet.models.detector.detector.Detector method)</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet.forward">(graphnet.models.gnn.convnet.ConvNet method)</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge.forward">(graphnet.models.gnn.dynedge.DynEdge method)</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward">(graphnet.models.gnn.dynedge_jinst.DynEdgeJINST method)</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward">(graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO method)</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN.forward">(graphnet.models.gnn.gnn.GNN method)</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition.forward">(graphnet.models.graphs.edges.edges.EdgeDefinition method)</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition.forward">(graphnet.models.graphs.graph_definition.GraphDefinition method)</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward">(graphnet.models.graphs.nodes.nodes.NodeDefinition method)</a>
+</li>
+        <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.forward">(graphnet.models.model.Model method)</a>
+</li>
+        <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.forward">(graphnet.models.standard_model.StandardModel method)</a>
+</li>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.forward">(graphnet.models.task.task.Task method)</a>
+</li>
+        <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK.forward">(graphnet.training.loss_functions.LogCMK static method)</a>
+</li>
+        <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction.forward">(graphnet.training.loss_functions.LossFunction method)</a>
+</li>
+      </ul></li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.data.extractors.utilities.frames.html#graphnet.data.extractors.utilities.frames.frame_is_montecarlo">frame_is_montecarlo() (in module graphnet.data.extractors.utilities.frames)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.frames.html#graphnet.data.extractors.utilities.frames.frame_is_noise">frame_is_noise() (in module graphnet.data.extractors.utilities.frames)</a>
+</li>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.from_config">from_config() (graphnet.data.dataset.dataset.Dataset class method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.from_config">(graphnet.models.model.Model class method)</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.from_config">(graphnet.utilities.config.configurable.Configurable class method)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader.from_dataset_config">from_dataset_config() (graphnet.data.dataloader.DataLoader class method)</a>
 </li>
   </ul></td>
 </tr></table>
@@ -478,12 +735,30 @@ <h2 id="G">G</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.FileSet.gcd_file">gcd_file (graphnet.data.dataconverter.FileSet attribute)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.get_all_argument_values">get_all_argument_values() (in module graphnet.utilities.config.base_config)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_all_grapnet_classes">get_all_grapnet_classes() (in module graphnet.utilities.config.parsing)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_graphnet_classes">get_graphnet_classes() (in module graphnet.utilities.config.parsing)</a>
+</li>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR.get_lr">get_lr() (graphnet.training.callbacks.PiecewiseLinearLR method)</a>
 </li>
       <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter.get_map_function">get_map_function() (graphnet.data.dataconverter.DataConverter method)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.get_member_variables">get_member_variables() (in module graphnet.data.extractors.utilities.types)</a>
+</li>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.get_metrics">get_metrics() (graphnet.training.callbacks.ProgressBar method)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.frames.html#graphnet.data.extractors.utilities.frames.get_om_keys_and_pulseseries">get_om_keys_and_pulseseries() (in module graphnet.data.extractors.utilities.frames)</a>
+</li>
+      <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.get_predictions">get_predictions() (in module graphnet.training.utils)</a>
+</li>
+      <li><a href="api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN">GNN (class in graphnet.models.gnn.gnn)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition">graph_definition (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition">GraphDefinition (class in graphnet.models.graphs.graph_definition)</a>
 </li>
       <li>
     graphnet
@@ -518,6 +793,62 @@ <h2 id="G">G</h2>
 
       <ul>
         <li><a href="api/graphnet.data.dataconverter.html#module-graphnet.data.dataconverter">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.dataloader
+
+      <ul>
+        <li><a href="api/graphnet.data.dataloader.html#module-graphnet.data.dataloader">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.dataset
+
+      <ul>
+        <li><a href="api/graphnet.data.dataset.html#module-graphnet.data.dataset">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.dataset.dataset
+
+      <ul>
+        <li><a href="api/graphnet.data.dataset.dataset.html#module-graphnet.data.dataset.dataset">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.dataset.parquet
+
+      <ul>
+        <li><a href="api/graphnet.data.dataset.parquet.html#module-graphnet.data.dataset.parquet">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.dataset.parquet.parquet_dataset
+
+      <ul>
+        <li><a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#module-graphnet.data.dataset.parquet.parquet_dataset">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.dataset.sqlite
+
+      <ul>
+        <li><a href="api/graphnet.data.dataset.sqlite.html#module-graphnet.data.dataset.sqlite">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.dataset.sqlite.sqlite_dataset
+
+      <ul>
+        <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#module-graphnet.data.dataset.sqlite.sqlite_dataset">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.dataset.sqlite.sqlite_dataset_perturbed
+
+      <ul>
+        <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed">module</a>
 </li>
       </ul></li>
       <li>
@@ -632,8 +963,6 @@ <h2 id="G">G</h2>
         <li><a href="api/graphnet.data.extractors.utilities.frames.html#module-graphnet.data.extractors.utilities.frames">module</a>
 </li>
       </ul></li>
-  </ul></td>
-  <td style="width: 33%; vertical-align: top;"><ul>
       <li>
     graphnet.data.extractors.utilities.types
 
@@ -653,6 +982,13 @@ <h2 id="G">G</h2>
 
       <ul>
         <li><a href="api/graphnet.data.parquet.parquet_dataconverter.html#module-graphnet.data.parquet.parquet_dataconverter">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.data.pipeline
+
+      <ul>
+        <li><a href="api/graphnet.data.pipeline.html#module-graphnet.data.pipeline">module</a>
 </li>
       </ul></li>
       <li>
@@ -712,95 +1048,399 @@ <h2 id="G">G</h2>
 </li>
       </ul></li>
       <li>
-    graphnet.pisa
+    graphnet.deployment.i3modules.graphnet_module
 
       <ul>
-        <li><a href="api/graphnet.pisa.html#module-graphnet.pisa">module</a>
+        <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#module-graphnet.deployment.i3modules.graphnet_module">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.pisa.fitting
+    graphnet.models
 
       <ul>
-        <li><a href="api/graphnet.pisa.fitting.html#module-graphnet.pisa.fitting">module</a>
+        <li><a href="api/graphnet.models.html#module-graphnet.models">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.pisa.plotting
+    graphnet.models.coarsening
 
       <ul>
-        <li><a href="api/graphnet.pisa.plotting.html#module-graphnet.pisa.plotting">module</a>
+        <li><a href="api/graphnet.models.coarsening.html#module-graphnet.models.coarsening">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.training
+    graphnet.models.components
 
       <ul>
-        <li><a href="api/graphnet.training.html#module-graphnet.training">module</a>
+        <li><a href="api/graphnet.models.components.html#module-graphnet.models.components">module</a>
 </li>
       </ul></li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
       <li>
-    graphnet.training.weight_fitting
+    graphnet.models.components.layers
 
       <ul>
-        <li><a href="api/graphnet.training.weight_fitting.html#module-graphnet.training.weight_fitting">module</a>
+        <li><a href="api/graphnet.models.components.layers.html#module-graphnet.models.components.layers">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.utilities
+    graphnet.models.components.pool
 
       <ul>
-        <li><a href="api/graphnet.utilities.html#module-graphnet.utilities">module</a>
+        <li><a href="api/graphnet.models.components.pool.html#module-graphnet.models.components.pool">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.utilities.argparse
+    graphnet.models.detector
 
       <ul>
-        <li><a href="api/graphnet.utilities.argparse.html#module-graphnet.utilities.argparse">module</a>
+        <li><a href="api/graphnet.models.detector.html#module-graphnet.models.detector">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.utilities.decorators
+    graphnet.models.detector.detector
 
       <ul>
-        <li><a href="api/graphnet.utilities.decorators.html#module-graphnet.utilities.decorators">module</a>
+        <li><a href="api/graphnet.models.detector.detector.html#module-graphnet.models.detector.detector">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.utilities.filesys
+    graphnet.models.detector.icecube
 
       <ul>
-        <li><a href="api/graphnet.utilities.filesys.html#module-graphnet.utilities.filesys">module</a>
+        <li><a href="api/graphnet.models.detector.icecube.html#module-graphnet.models.detector.icecube">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.utilities.imports
+    graphnet.models.detector.prometheus
 
       <ul>
-        <li><a href="api/graphnet.utilities.imports.html#module-graphnet.utilities.imports">module</a>
+        <li><a href="api/graphnet.models.detector.prometheus.html#module-graphnet.models.detector.prometheus">module</a>
 </li>
       </ul></li>
       <li>
-    graphnet.utilities.logging
+    graphnet.models.gnn
 
       <ul>
-        <li><a href="api/graphnet.utilities.logging.html#module-graphnet.utilities.logging">module</a>
+        <li><a href="api/graphnet.models.gnn.html#module-graphnet.models.gnn">module</a>
 </li>
       </ul></li>
-  </ul></td>
-</tr></table>
+      <li>
+    graphnet.models.gnn.convnet
 
-<h2 id="H">H</h2>
-<table style="width: 100%" class="indextable genindextable"><tr>
-  <td style="width: 33%; vertical-align: top;"><ul>
-      <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.handlers">handlers (graphnet.utilities.logging.Logger property)</a>
+      <ul>
+        <li><a href="api/graphnet.models.gnn.convnet.html#module-graphnet.models.gnn.convnet">module</a>
 </li>
-      <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.has_extension">has_extension() (in module graphnet.utilities.filesys)</a>
+      </ul></li>
+      <li>
+    graphnet.models.gnn.dynedge
+
+      <ul>
+        <li><a href="api/graphnet.models.gnn.dynedge.html#module-graphnet.models.gnn.dynedge">module</a>
 </li>
-  </ul></td>
-  <td style="width: 33%; vertical-align: top;"><ul>
-      <li><a href="api/graphnet.utilities.imports.html#graphnet.utilities.imports.has_icecube_package">has_icecube_package() (in module graphnet.utilities.imports)</a>
+      </ul></li>
+      <li>
+    graphnet.models.gnn.dynedge_jinst
+
+      <ul>
+        <li><a href="api/graphnet.models.gnn.dynedge_jinst.html#module-graphnet.models.gnn.dynedge_jinst">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.gnn.dynedge_kaggle_tito
+
+      <ul>
+        <li><a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#module-graphnet.models.gnn.dynedge_kaggle_tito">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.gnn.gnn
+
+      <ul>
+        <li><a href="api/graphnet.models.gnn.gnn.html#module-graphnet.models.gnn.gnn">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.graphs
+
+      <ul>
+        <li><a href="api/graphnet.models.graphs.html#module-graphnet.models.graphs">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.graphs.edges
+
+      <ul>
+        <li><a href="api/graphnet.models.graphs.edges.html#module-graphnet.models.graphs.edges">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.graphs.edges.edges
+
+      <ul>
+        <li><a href="api/graphnet.models.graphs.edges.edges.html#module-graphnet.models.graphs.edges.edges">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.graphs.graph_definition
+
+      <ul>
+        <li><a href="api/graphnet.models.graphs.graph_definition.html#module-graphnet.models.graphs.graph_definition">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.graphs.graphs
+
+      <ul>
+        <li><a href="api/graphnet.models.graphs.graphs.html#module-graphnet.models.graphs.graphs">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.graphs.nodes
+
+      <ul>
+        <li><a href="api/graphnet.models.graphs.nodes.html#module-graphnet.models.graphs.nodes">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.graphs.nodes.nodes
+
+      <ul>
+        <li><a href="api/graphnet.models.graphs.nodes.nodes.html#module-graphnet.models.graphs.nodes.nodes">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.model
+
+      <ul>
+        <li><a href="api/graphnet.models.model.html#module-graphnet.models.model">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.standard_model
+
+      <ul>
+        <li><a href="api/graphnet.models.standard_model.html#module-graphnet.models.standard_model">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.task
+
+      <ul>
+        <li><a href="api/graphnet.models.task.html#module-graphnet.models.task">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.task.classification
+
+      <ul>
+        <li><a href="api/graphnet.models.task.classification.html#module-graphnet.models.task.classification">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.task.reconstruction
+
+      <ul>
+        <li><a href="api/graphnet.models.task.reconstruction.html#module-graphnet.models.task.reconstruction">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.task.task
+
+      <ul>
+        <li><a href="api/graphnet.models.task.task.html#module-graphnet.models.task.task">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.models.utils
+
+      <ul>
+        <li><a href="api/graphnet.models.utils.html#module-graphnet.models.utils">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.pisa
+
+      <ul>
+        <li><a href="api/graphnet.pisa.html#module-graphnet.pisa">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.pisa.fitting
+
+      <ul>
+        <li><a href="api/graphnet.pisa.fitting.html#module-graphnet.pisa.fitting">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.pisa.plotting
+
+      <ul>
+        <li><a href="api/graphnet.pisa.plotting.html#module-graphnet.pisa.plotting">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.training
+
+      <ul>
+        <li><a href="api/graphnet.training.html#module-graphnet.training">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.training.callbacks
+
+      <ul>
+        <li><a href="api/graphnet.training.callbacks.html#module-graphnet.training.callbacks">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.training.labels
+
+      <ul>
+        <li><a href="api/graphnet.training.labels.html#module-graphnet.training.labels">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.training.loss_functions
+
+      <ul>
+        <li><a href="api/graphnet.training.loss_functions.html#module-graphnet.training.loss_functions">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.training.utils
+
+      <ul>
+        <li><a href="api/graphnet.training.utils.html#module-graphnet.training.utils">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.training.weight_fitting
+
+      <ul>
+        <li><a href="api/graphnet.training.weight_fitting.html#module-graphnet.training.weight_fitting">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities
+
+      <ul>
+        <li><a href="api/graphnet.utilities.html#module-graphnet.utilities">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.argparse
+
+      <ul>
+        <li><a href="api/graphnet.utilities.argparse.html#module-graphnet.utilities.argparse">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.config
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.html#module-graphnet.utilities.config">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.config.base_config
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.base_config.html#module-graphnet.utilities.config.base_config">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.config.configurable
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.configurable.html#module-graphnet.utilities.config.configurable">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.config.dataset_config
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.dataset_config.html#module-graphnet.utilities.config.dataset_config">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.config.model_config
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.model_config.html#module-graphnet.utilities.config.model_config">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.config.parsing
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.parsing.html#module-graphnet.utilities.config.parsing">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.config.training_config
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.training_config.html#module-graphnet.utilities.config.training_config">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.decorators
+
+      <ul>
+        <li><a href="api/graphnet.utilities.decorators.html#module-graphnet.utilities.decorators">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.filesys
+
+      <ul>
+        <li><a href="api/graphnet.utilities.filesys.html#module-graphnet.utilities.filesys">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.imports
+
+      <ul>
+        <li><a href="api/graphnet.utilities.imports.html#module-graphnet.utilities.imports">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.logging
+
+      <ul>
+        <li><a href="api/graphnet.utilities.logging.html#module-graphnet.utilities.logging">module</a>
+</li>
+      </ul></li>
+      <li>
+    graphnet.utilities.maths
+
+      <ul>
+        <li><a href="api/graphnet.utilities.maths.html#module-graphnet.utilities.maths">module</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module">GraphNeTI3Module (class in graphnet.deployment.i3modules.graphnet_module)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_by">group_by() (in module graphnet.models.components.pool)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_dom">group_pulses_to_dom() (in module graphnet.models.components.pool)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_pmt">group_pulses_to_pmt() (in module graphnet.models.components.pool)</a>
+</li>
+  </ul></td>
+</tr></table>
+
+<h2 id="H">H</h2>
+<table style="width: 100%" class="indextable genindextable"><tr>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.handlers">handlers (graphnet.utilities.logging.Logger property)</a>
+</li>
+      <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.has_extension">has_extension() (in module graphnet.utilities.filesys)</a>
+</li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.utilities.imports.html#graphnet.utilities.imports.has_icecube_package">has_icecube_package() (in module graphnet.utilities.imports)</a>
 </li>
       <li><a href="api/graphnet.utilities.imports.html#graphnet.utilities.imports.has_pisa_package">has_pisa_package() (in module graphnet.utilities.imports)</a>
 </li>
@@ -829,12 +1469,16 @@ <h2 id="I">I</h2>
       <li><a href="api/graphnet.data.extractors.i3hybridrecoextractor.html#graphnet.data.extractors.i3hybridrecoextractor.I3GalacticPlaneHybridRecoExtractor">I3GalacticPlaneHybridRecoExtractor (class in graphnet.data.extractors.i3hybridrecoextractor)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.i3genericextractor.html#graphnet.data.extractors.i3genericextractor.I3GenericExtractor">I3GenericExtractor (class in graphnet.data.extractors.i3genericextractor)</a>
+</li>
+      <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule">I3InferenceModule (class in graphnet.deployment.i3modules.graphnet_module)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.i3ntmuonlabelsextractor.html#graphnet.data.extractors.i3ntmuonlabelsextractor.I3NTMuonLabelExtractor">I3NTMuonLabelExtractor (class in graphnet.data.extractors.i3ntmuonlabelsextractor)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.i3particleextractor.html#graphnet.data.extractors.i3particleextractor.I3ParticleExtractor">I3ParticleExtractor (class in graphnet.data.extractors.i3particleextractor)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.i3pisaextractor.html#graphnet.data.extractors.i3pisaextractor.I3PISAExtractor">I3PISAExtractor (class in graphnet.data.extractors.i3pisaextractor)</a>
+</li>
+      <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule">I3PulseCleanerModule (class in graphnet.deployment.i3modules.graphnet_module)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3PulseNoiseTruthFlagIceCubeUpgrade">I3PulseNoiseTruthFlagIceCubeUpgrade (class in graphnet.data.extractors.i3featureextractor)</a>
 </li>
@@ -842,29 +1486,63 @@ <h2 id="I">I</h2>
 </li>
       <li><a href="api/graphnet.data.extractors.i3retroextractor.html#graphnet.data.extractors.i3retroextractor.I3RetroExtractor">I3RetroExtractor (class in graphnet.data.extractors.i3retroextractor)</a>
 </li>
-  </ul></td>
-  <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.data.extractors.i3splinempeextractor.html#graphnet.data.extractors.i3splinempeextractor.I3SplineMPEICExtractor">I3SplineMPEICExtractor (class in graphnet.data.extractors.i3splinempeextractor)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.i3truthextractor.html#graphnet.data.extractors.i3truthextractor.I3TruthExtractor">I3TruthExtractor (class in graphnet.data.extractors.i3truthextractor)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.i3tumextractor.html#graphnet.data.extractors.i3tumextractor.I3TUMExtractor">I3TUMExtractor (class in graphnet.data.extractors.i3tumextractor)</a>
+</li>
+      <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86">IceCube86 (class in graphnet.models.detector.icecube)</a>
 </li>
       <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.FEATURES.ICECUBE86">ICECUBE86 (graphnet.data.constants.FEATURES attribute)</a>
 
       <ul>
         <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH.ICECUBE86">(graphnet.data.constants.TRUTH attribute)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore">IceCubeDeepCore (class in graphnet.models.detector.icecube)</a>
+</li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle">IceCubeKaggle (class in graphnet.models.detector.icecube)</a>
+</li>
+      <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade">IceCubeUpgrade (class in graphnet.models.detector.icecube)</a>
+</li>
+      <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask">IdentityTask (class in graphnet.models.task.task)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.index_column">index_column (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction">InelasticityReconstruction (class in graphnet.models.task.reconstruction)</a>
+</li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.inference">inference() (graphnet.models.standard_model.StandardModel method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.inference">(graphnet.models.task.task.Task method)</a>
 </li>
       </ul></li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.info">info() (graphnet.utilities.logging.Logger method)</a>
 </li>
       <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.init_global_index">init_global_index() (in module graphnet.data.dataconverter)</a>
+</li>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_predict_tqdm">init_predict_tqdm() (graphnet.training.callbacks.ProgressBar method)</a>
+</li>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_test_tqdm">init_test_tqdm() (graphnet.training.callbacks.ProgressBar method)</a>
+</li>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_train_tqdm">init_train_tqdm() (graphnet.training.callbacks.ProgressBar method)</a>
+</li>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_validation_tqdm">init_validation_tqdm() (graphnet.training.callbacks.ProgressBar method)</a>
+</li>
+      <li><a href="api/graphnet.data.pipeline.html#graphnet.data.pipeline.InSQLitePipeline">InSQLitePipeline (class in graphnet.data.pipeline)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.is_boost_class">is_boost_class() (in module graphnet.data.extractors.utilities.types)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.is_boost_enum">is_boost_enum() (in module graphnet.data.extractors.utilities.types)</a>
 </li>
       <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.is_gcd_file">is_gcd_file() (in module graphnet.utilities.filesys)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_class">is_graphnet_class() (in module graphnet.utilities.config.parsing)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_module">is_graphnet_module() (in module graphnet.utilities.config.parsing)</a>
 </li>
       <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.is_i3_file">is_i3_file() (in module graphnet.utilities.filesys)</a>
 </li>
@@ -890,13 +1568,57 @@ <h2 id="K">K</h2>
         <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH.KAGGLE">(graphnet.data.constants.TRUTH attribute)</a>
 </li>
       </ul></li>
+      <li><a href="api/graphnet.training.labels.html#graphnet.training.labels.Label.key">key (graphnet.training.labels.Label property)</a>
+</li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.utils.html#graphnet.models.utils.knn_graph_batch">knn_graph_batch() (in module graphnet.models.utils)</a>
+</li>
+      <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.KNNEdges">KNNEdges (class in graphnet.models.graphs.edges.edges)</a>
+</li>
+      <li><a href="api/graphnet.models.graphs.graphs.html#graphnet.models.graphs.graphs.KNNGraph">KNNGraph (class in graphnet.models.graphs.graphs)</a>
+</li>
   </ul></td>
 </tr></table>
 
 <h2 id="L">L</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.training.labels.html#graphnet.training.labels.Label">Label (class in graphnet.training.labels)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.list_all_submodules">list_all_submodules() (in module graphnet.utilities.config.parsing)</a>
+</li>
+      <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.load">load() (graphnet.models.model.Model class method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.load">(graphnet.utilities.config.base_config.BaseConfig class method)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.load_module">load_module() (in module graphnet.data.dataset.dataset)</a>
+</li>
+      <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.load_state_dict">load_state_dict() (graphnet.models.model.Model method)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk">log_cmk() (graphnet.training.loss_functions.VonMisesFisherLoss class method)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx">log_cmk_approx() (graphnet.training.loss_functions.VonMisesFisherLoss class method)</a>
+</li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact">log_cmk_exact() (graphnet.training.loss_functions.VonMisesFisherLoss class method)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK">LogCMK (class in graphnet.training.loss_functions)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCoshLoss">LogCoshLoss (class in graphnet.training.loss_functions)</a>
+</li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger">Logger (class in graphnet.utilities.logging)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column">loss_weight_column (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value">loss_weight_default_value (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table">loss_weight_table (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction">LossFunction (class in graphnet.training.loss_functions)</a>
 </li>
   </ul></td>
 </tr></table>
@@ -904,6 +1626,10 @@ <h2 id="L">L</h2>
 <h2 id="M">M</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.make_dataloader">make_dataloader() (in module graphnet.training.utils)</a>
+</li>
+      <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.make_train_validation_dataloader">make_train_validation_dataloader() (in module graphnet.training.utils)</a>
+</li>
       <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter.merge_files">merge_files() (graphnet.data.dataconverter.DataConverter method)</a>
 
       <ul>
@@ -912,6 +1638,36 @@ <h2 id="M">M</h2>
         <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.merge_files">(graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter method)</a>
 </li>
       </ul></li>
+      <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.message">message() (graphnet.models.components.layers.EdgeConvTito method)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool">min_pool() (in module graphnet.models.components.pool)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool_x">min_pool_x() (in module graphnet.models.components.pool)</a>
+</li>
+      <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model">Model (class in graphnet.models.model)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.model_config">model_config (graphnet.utilities.config.base_config.BaseConfig attribute)</a>
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.model_config">(graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.model_config">(graphnet.utilities.config.model_config.ModelConfig attribute)</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.model_config">(graphnet.utilities.config.training_config.TrainingConfig attribute)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.model_fields">model_fields (graphnet.utilities.config.base_config.BaseConfig attribute)</a>
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields">(graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.model_fields">(graphnet.utilities.config.model_config.ModelConfig attribute)</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.model_fields">(graphnet.utilities.config.training_config.TrainingConfig attribute)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig">ModelConfig (class in graphnet.utilities.config.model_config)</a>
+</li>
       <li>
     module
 
@@ -925,6 +1681,22 @@ <h2 id="M">M</h2>
         <li><a href="api/graphnet.data.constants.html#module-graphnet.data.constants">graphnet.data.constants</a>
 </li>
         <li><a href="api/graphnet.data.dataconverter.html#module-graphnet.data.dataconverter">graphnet.data.dataconverter</a>
+</li>
+        <li><a href="api/graphnet.data.dataloader.html#module-graphnet.data.dataloader">graphnet.data.dataloader</a>
+</li>
+        <li><a href="api/graphnet.data.dataset.html#module-graphnet.data.dataset">graphnet.data.dataset</a>
+</li>
+        <li><a href="api/graphnet.data.dataset.dataset.html#module-graphnet.data.dataset.dataset">graphnet.data.dataset.dataset</a>
+</li>
+        <li><a href="api/graphnet.data.dataset.parquet.html#module-graphnet.data.dataset.parquet">graphnet.data.dataset.parquet</a>
+</li>
+        <li><a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#module-graphnet.data.dataset.parquet.parquet_dataset">graphnet.data.dataset.parquet.parquet_dataset</a>
+</li>
+        <li><a href="api/graphnet.data.dataset.sqlite.html#module-graphnet.data.dataset.sqlite">graphnet.data.dataset.sqlite</a>
+</li>
+        <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#module-graphnet.data.dataset.sqlite.sqlite_dataset">graphnet.data.dataset.sqlite.sqlite_dataset</a>
+</li>
+        <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed">graphnet.data.dataset.sqlite.sqlite_dataset_perturbed</a>
 </li>
         <li><a href="api/graphnet.data.extractors.html#module-graphnet.data.extractors">graphnet.data.extractors</a>
 </li>
@@ -963,6 +1735,8 @@ <h2 id="M">M</h2>
         <li><a href="api/graphnet.data.parquet.html#module-graphnet.data.parquet">graphnet.data.parquet</a>
 </li>
         <li><a href="api/graphnet.data.parquet.parquet_dataconverter.html#module-graphnet.data.parquet.parquet_dataconverter">graphnet.data.parquet.parquet_dataconverter</a>
+</li>
+        <li><a href="api/graphnet.data.pipeline.html#module-graphnet.data.pipeline">graphnet.data.pipeline</a>
 </li>
         <li><a href="api/graphnet.data.sqlite.html#module-graphnet.data.sqlite">graphnet.data.sqlite</a>
 </li>
@@ -979,6 +1753,66 @@ <h2 id="M">M</h2>
         <li><a href="api/graphnet.data.utilities.string_selection_resolver.html#module-graphnet.data.utilities.string_selection_resolver">graphnet.data.utilities.string_selection_resolver</a>
 </li>
         <li><a href="api/graphnet.deployment.html#module-graphnet.deployment">graphnet.deployment</a>
+</li>
+        <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#module-graphnet.deployment.i3modules.graphnet_module">graphnet.deployment.i3modules.graphnet_module</a>
+</li>
+        <li><a href="api/graphnet.models.html#module-graphnet.models">graphnet.models</a>
+</li>
+        <li><a href="api/graphnet.models.coarsening.html#module-graphnet.models.coarsening">graphnet.models.coarsening</a>
+</li>
+        <li><a href="api/graphnet.models.components.html#module-graphnet.models.components">graphnet.models.components</a>
+</li>
+        <li><a href="api/graphnet.models.components.layers.html#module-graphnet.models.components.layers">graphnet.models.components.layers</a>
+</li>
+        <li><a href="api/graphnet.models.components.pool.html#module-graphnet.models.components.pool">graphnet.models.components.pool</a>
+</li>
+        <li><a href="api/graphnet.models.detector.html#module-graphnet.models.detector">graphnet.models.detector</a>
+</li>
+        <li><a href="api/graphnet.models.detector.detector.html#module-graphnet.models.detector.detector">graphnet.models.detector.detector</a>
+</li>
+        <li><a href="api/graphnet.models.detector.icecube.html#module-graphnet.models.detector.icecube">graphnet.models.detector.icecube</a>
+</li>
+        <li><a href="api/graphnet.models.detector.prometheus.html#module-graphnet.models.detector.prometheus">graphnet.models.detector.prometheus</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.html#module-graphnet.models.gnn">graphnet.models.gnn</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.convnet.html#module-graphnet.models.gnn.convnet">graphnet.models.gnn.convnet</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.dynedge.html#module-graphnet.models.gnn.dynedge">graphnet.models.gnn.dynedge</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.dynedge_jinst.html#module-graphnet.models.gnn.dynedge_jinst">graphnet.models.gnn.dynedge_jinst</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#module-graphnet.models.gnn.dynedge_kaggle_tito">graphnet.models.gnn.dynedge_kaggle_tito</a>
+</li>
+        <li><a href="api/graphnet.models.gnn.gnn.html#module-graphnet.models.gnn.gnn">graphnet.models.gnn.gnn</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.html#module-graphnet.models.graphs">graphnet.models.graphs</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.edges.html#module-graphnet.models.graphs.edges">graphnet.models.graphs.edges</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.edges.edges.html#module-graphnet.models.graphs.edges.edges">graphnet.models.graphs.edges.edges</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.graph_definition.html#module-graphnet.models.graphs.graph_definition">graphnet.models.graphs.graph_definition</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.graphs.html#module-graphnet.models.graphs.graphs">graphnet.models.graphs.graphs</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.nodes.html#module-graphnet.models.graphs.nodes">graphnet.models.graphs.nodes</a>
+</li>
+        <li><a href="api/graphnet.models.graphs.nodes.nodes.html#module-graphnet.models.graphs.nodes.nodes">graphnet.models.graphs.nodes.nodes</a>
+</li>
+        <li><a href="api/graphnet.models.model.html#module-graphnet.models.model">graphnet.models.model</a>
+</li>
+        <li><a href="api/graphnet.models.standard_model.html#module-graphnet.models.standard_model">graphnet.models.standard_model</a>
+</li>
+        <li><a href="api/graphnet.models.task.html#module-graphnet.models.task">graphnet.models.task</a>
+</li>
+        <li><a href="api/graphnet.models.task.classification.html#module-graphnet.models.task.classification">graphnet.models.task.classification</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#module-graphnet.models.task.reconstruction">graphnet.models.task.reconstruction</a>
+</li>
+        <li><a href="api/graphnet.models.task.task.html#module-graphnet.models.task.task">graphnet.models.task.task</a>
+</li>
+        <li><a href="api/graphnet.models.utils.html#module-graphnet.models.utils">graphnet.models.utils</a>
 </li>
         <li><a href="api/graphnet.pisa.html#module-graphnet.pisa">graphnet.pisa</a>
 </li>
@@ -987,12 +1821,34 @@ <h2 id="M">M</h2>
         <li><a href="api/graphnet.pisa.plotting.html#module-graphnet.pisa.plotting">graphnet.pisa.plotting</a>
 </li>
         <li><a href="api/graphnet.training.html#module-graphnet.training">graphnet.training</a>
+</li>
+        <li><a href="api/graphnet.training.callbacks.html#module-graphnet.training.callbacks">graphnet.training.callbacks</a>
+</li>
+        <li><a href="api/graphnet.training.labels.html#module-graphnet.training.labels">graphnet.training.labels</a>
+</li>
+        <li><a href="api/graphnet.training.loss_functions.html#module-graphnet.training.loss_functions">graphnet.training.loss_functions</a>
+</li>
+        <li><a href="api/graphnet.training.utils.html#module-graphnet.training.utils">graphnet.training.utils</a>
 </li>
         <li><a href="api/graphnet.training.weight_fitting.html#module-graphnet.training.weight_fitting">graphnet.training.weight_fitting</a>
 </li>
         <li><a href="api/graphnet.utilities.html#module-graphnet.utilities">graphnet.utilities</a>
 </li>
         <li><a href="api/graphnet.utilities.argparse.html#module-graphnet.utilities.argparse">graphnet.utilities.argparse</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.html#module-graphnet.utilities.config">graphnet.utilities.config</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.base_config.html#module-graphnet.utilities.config.base_config">graphnet.utilities.config.base_config</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.configurable.html#module-graphnet.utilities.config.configurable">graphnet.utilities.config.configurable</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.dataset_config.html#module-graphnet.utilities.config.dataset_config">graphnet.utilities.config.dataset_config</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.model_config.html#module-graphnet.utilities.config.model_config">graphnet.utilities.config.model_config</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.parsing.html#module-graphnet.utilities.config.parsing">graphnet.utilities.config.parsing</a>
+</li>
+        <li><a href="api/graphnet.utilities.config.training_config.html#module-graphnet.utilities.config.training_config">graphnet.utilities.config.training_config</a>
 </li>
         <li><a href="api/graphnet.utilities.decorators.html#module-graphnet.utilities.decorators">graphnet.utilities.decorators</a>
 </li>
@@ -1001,9 +1857,17 @@ <h2 id="M">M</h2>
         <li><a href="api/graphnet.utilities.imports.html#module-graphnet.utilities.imports">graphnet.utilities.imports</a>
 </li>
         <li><a href="api/graphnet.utilities.logging.html#module-graphnet.utilities.logging">graphnet.utilities.logging</a>
+</li>
+        <li><a href="api/graphnet.utilities.maths.html#module-graphnet.utilities.maths">graphnet.utilities.maths</a>
 </li>
       </ul></li>
   </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.MSELoss">MSELoss (class in graphnet.training.loss_functions)</a>
+</li>
+      <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.MulticlassClassificationTask">MulticlassClassificationTask (class in graphnet.models.task.classification)</a>
+</li>
+  </ul></td>
 </tr></table>
 
 <h2 id="N">N</h2>
@@ -1011,9 +1875,59 @@ <h2 id="N">N</h2>
   <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.data.extractors.i3extractor.html#graphnet.data.extractors.i3extractor.I3Extractor.name">name (graphnet.data.extractors.i3extractor.I3Extractor property)</a>
 </li>
+      <li><a href="api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN.nb_inputs">nb_inputs (graphnet.models.gnn.gnn.GNN property)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs">(graphnet.models.task.classification.BinaryClassificationTask attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs">(graphnet.models.task.classification.BinaryClassificationTaskLogits attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs">(graphnet.models.task.reconstruction.AzimuthReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs">(graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs">(graphnet.models.task.reconstruction.DirectionReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs">(graphnet.models.task.reconstruction.EnergyReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs">(graphnet.models.task.reconstruction.EnergyReconstructionWithPower attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs">(graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs">(graphnet.models.task.reconstruction.InelasticityReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs">(graphnet.models.task.reconstruction.PositionReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs">(graphnet.models.task.reconstruction.TimeReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs">(graphnet.models.task.reconstruction.VertexReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs">(graphnet.models.task.reconstruction.ZenithReconstruction attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs">(graphnet.models.task.reconstruction.ZenithReconstructionWithKappa attribute)</a>
+</li>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask.nb_inputs">(graphnet.models.task.task.IdentityTask property)</a>
+</li>
+        <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.nb_inputs">(graphnet.models.task.task.Task property)</a>
+</li>
+      </ul></li>
   </ul></td>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN.nb_outputs">nb_outputs (graphnet.models.gnn.gnn.GNN property)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs">(graphnet.models.graphs.nodes.nodes.NodeDefinition property)</a>
+</li>
+      </ul></li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.RepeatFilter.nb_repeats_allowed">nb_repeats_allowed (graphnet.utilities.logging.RepeatFilter attribute)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth">node_truth (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table">node_truth_table (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition">NodeDefinition (class in graphnet.models.graphs.nodes.nodes)</a>
+</li>
+      <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodesAsPulses">NodesAsPulses (class in graphnet.models.graphs.nodes.nodes)</a>
 </li>
   </ul></td>
 </tr></table>
@@ -1021,6 +1935,12 @@ <h2 id="N">N</h2>
 <h2 id="O">O</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.on_train_epoch_end">on_train_epoch_end() (graphnet.training.callbacks.ProgressBar method)</a>
+</li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.on_train_epoch_start">on_train_epoch_start() (graphnet.training.callbacks.ProgressBar method)</a>
+</li>
       <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.Options">Options (class in graphnet.utilities.argparse)</a>
 </li>
   </ul></td>
@@ -1032,21 +1952,69 @@ <h2 id="P">P</h2>
       <li><a href="api/graphnet.data.utilities.random.html#graphnet.data.utilities.random.pairwise_shuffle">pairwise_shuffle() (in module graphnet.data.utilities.random)</a>
 </li>
       <li><a href="api/graphnet.data.parquet.parquet_dataconverter.html#graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter">ParquetDataConverter (class in graphnet.data.parquet.parquet_dataconverter)</a>
+</li>
+      <li><a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset">ParquetDataset (class in graphnet.data.dataset.parquet.parquet_dataset)</a>
 </li>
       <li><a href="api/graphnet.data.utilities.parquet_to_sqlite.html#graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter">ParquetToSQLiteConverter (class in graphnet.data.utilities.parquet_to_sqlite)</a>
+</li>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.parse_graph_definition">parse_graph_definition() (in module graphnet.data.dataset.dataset)</a>
+</li>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.path">path (graphnet.data.dataset.dataset.Dataset property)</a>
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.path">(graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR">PiecewiseLinearLR (class in graphnet.training.callbacks)</a>
 </li>
       <li><a href="api/graphnet.pisa.plotting.html#graphnet.pisa.plotting.plot_1D_contour">plot_1D_contour() (in module graphnet.pisa.plotting)</a>
 </li>
-  </ul></td>
-  <td style="width: 33%; vertical-align: top;"><ul>
       <li><a href="api/graphnet.pisa.plotting.html#graphnet.pisa.plotting.plot_2D_contour">plot_2D_contour() (in module graphnet.pisa.plotting)</a>
 </li>
       <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.Options.pop_default">pop_default() (graphnet.utilities.argparse.Options method)</a>
+</li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction">PositionReconstruction (class in graphnet.models.task.reconstruction)</a>
+</li>
+      <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.predict">predict() (graphnet.models.model.Model method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.predict">(graphnet.models.standard_model.StandardModel method)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.predict_as_dataframe">predict_as_dataframe() (graphnet.models.model.Model method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.predict_as_dataframe">(graphnet.models.standard_model.StandardModel method)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.prediction_labels">prediction_labels (graphnet.models.standard_model.StandardModel property)</a>
+</li>
+      <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar">ProgressBar (class in graphnet.training.callbacks)</a>
+</li>
+      <li><a href="api/graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus">Prometheus (class in graphnet.models.detector.prometheus)</a>
 </li>
       <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.FEATURES.PROMETHEUS">PROMETHEUS (graphnet.data.constants.FEATURES attribute)</a>
 
       <ul>
         <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH.PROMETHEUS">(graphnet.data.constants.TRUTH attribute)</a>
+</li>
+      </ul></li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps">pulsemaps (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+  </ul></td>
+</tr></table>
+
+<h2 id="Q">Q</h2>
+<table style="width: 100%" class="indextable genindextable"><tr>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.query_table">query_table() (graphnet.data.dataset.dataset.Dataset method)</a>
+
+      <ul>
+        <li><a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table">(graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset method)</a>
+</li>
+        <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table">(graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset method)</a>
 </li>
       </ul></li>
   </ul></td>
@@ -1055,7 +2023,11 @@ <h2 id="P">P</h2>
 <h2 id="R">R</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.RadialEdges">RadialEdges (class in graphnet.models.graphs.edges.edges)</a>
+</li>
       <li><a href="api/graphnet.pisa.plotting.html#graphnet.pisa.plotting.read_entry">read_entry() (in module graphnet.pisa.plotting)</a>
+</li>
+      <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening.reduce_options">reduce_options (graphnet.models.coarsening.Coarsening attribute)</a>
 </li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.RepeatFilter">RepeatFilter (class in graphnet.utilities.logging)</a>
 </li>
@@ -1063,7 +2035,11 @@ <h2 id="R">R</h2>
 </li>
   </ul></td>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.reset_parameters">reset_parameters() (graphnet.models.components.layers.EdgeConvTito method)</a>
+</li>
       <li><a href="api/graphnet.data.utilities.string_selection_resolver.html#graphnet.data.utilities.string_selection_resolver.StringSelectionResolver.resolve">resolve() (graphnet.data.utilities.string_selection_resolver.StringSelectionResolver method)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.RMSELoss">RMSELoss (class in graphnet.training.loss_functions)</a>
 </li>
       <li><a href="api/graphnet.data.utilities.parquet_to_sqlite.html#graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter.run">run() (graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter method)</a>
 </li>
@@ -1075,6 +2051,10 @@ <h2 id="R">R</h2>
 <h2 id="S">S</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.save">save() (graphnet.models.model.Model method)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.save_config">save_config() (graphnet.utilities.config.configurable.Configurable method)</a>
+</li>
       <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter.save_data">save_data() (graphnet.data.dataconverter.DataConverter method)</a>
 
       <ul>
@@ -1083,7 +2063,19 @@ <h2 id="S">S</h2>
         <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.save_data">(graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter method)</a>
 </li>
       </ul></li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.save_dataset_config">save_dataset_config() (in module graphnet.utilities.config.dataset_config)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.save_model_config">save_model_config() (in module graphnet.utilities.config.model_config)</a>
+</li>
+      <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.save_results">save_results() (in module graphnet.training.utils)</a>
+</li>
+      <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.save_state_dict">save_state_dict() (graphnet.models.model.Model method)</a>
+</li>
       <li><a href="api/graphnet.data.sqlite.sqlite_utilities.html#graphnet.data.sqlite.sqlite_utilities.save_to_sql">save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.seed">seed (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.selection">selection (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
 </li>
       <li><a href="api/graphnet.data.extractors.utilities.collections.html#graphnet.data.extractors.utilities.collections.serialise">serialise() (in module graphnet.data.extractors.utilities.collections)</a>
 </li>
@@ -1095,15 +2087,37 @@ <h2 id="S">S</h2>
       </ul></li>
   </ul></td>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs">set_number_of_inputs() (graphnet.models.graphs.nodes.nodes.NodeDefinition method)</a>
+</li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.setLevel">setLevel() (graphnet.utilities.logging.Logger method)</a>
+</li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.shared_step">shared_step() (graphnet.models.standard_model.StandardModel method)</a>
 </li>
       <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter">SQLiteDataConverter (class in graphnet.data.sqlite.sqlite_dataconverter)</a>
+</li>
+      <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset">SQLiteDataset (class in graphnet.data.dataset.sqlite.sqlite_dataset)</a>
+</li>
+      <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed">SQLiteDatasetPerturbed (class in graphnet.data.dataset.sqlite.sqlite_dataset_perturbed)</a>
 </li>
       <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.ArgumentParser.standard_arguments">standard_arguments (graphnet.utilities.argparse.ArgumentParser attribute)</a>
+</li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel">StandardModel (class in graphnet.models.standard_model)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool">std_pool() (in module graphnet.models.components.pool)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool_x">std_pool_x() (in module graphnet.models.components.pool)</a>
 </li>
       <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.stream_handlers">stream_handlers (graphnet.utilities.logging.Logger property)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection">string_selection (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
 </li>
       <li><a href="api/graphnet.data.utilities.string_selection_resolver.html#graphnet.data.utilities.string_selection_resolver.StringSelectionResolver">StringSelectionResolver (class in graphnet.data.utilities.string_selection_resolver)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool">sum_pool() (in module graphnet.models.components.pool)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_and_distribute">sum_pool_and_distribute() (in module graphnet.models.components.pool)</a>
+</li>
+      <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_x">sum_pool_x() (in module graphnet.models.components.pool)</a>
 </li>
   </ul></td>
 </tr></table>
@@ -1111,18 +2125,46 @@ <h2 id="S">S</h2>
 <h2 id="T">T</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
-      <li><a href="api/graphnet.data.extractors.utilities.collections.html#graphnet.data.extractors.utilities.collections.transpose_list_of_dicts">transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)</a>
+      <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.target">target (graphnet.utilities.config.training_config.TrainingConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.target_labels">target_labels (graphnet.models.standard_model.StandardModel property)</a>
+</li>
+      <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task">Task (class in graphnet.models.task.task)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction">TimeReconstruction (class in graphnet.models.task.reconstruction)</a>
+</li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.train">train() (graphnet.models.standard_model.StandardModel method)</a>
+</li>
+      <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.train_eval">train_eval() (graphnet.models.task.task.Task method)</a>
+</li>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.training_step">training_step() (graphnet.models.standard_model.StandardModel method)</a>
 </li>
   </ul></td>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig">TrainingConfig (class in graphnet.utilities.config.training_config)</a>
+</li>
+      <li><a href="api/graphnet.data.extractors.utilities.collections.html#graphnet.data.extractors.utilities.collections.transpose_list_of_dicts">transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)</a>
+</li>
+      <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.traverse_and_apply">traverse_and_apply() (in module graphnet.utilities.config.parsing)</a>
+</li>
       <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH">TRUTH (class in graphnet.data.constants)</a>
 </li>
+      <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.truth">truth (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.truth_table">truth_table (graphnet.data.dataset.dataset.Dataset property)</a>
+
+      <ul>
+        <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table">(graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a>
+</li>
+      </ul></li>
   </ul></td>
 </tr></table>
 
 <h2 id="U">U</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.unbatch_edge_index">unbatch_edge_index() (in module graphnet.models.coarsening)</a>
+</li>
       <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.Uniform">Uniform (class in graphnet.training.weight_fitting)</a>
 </li>
   </ul></td>
@@ -1136,6 +2178,24 @@ <h2 id="U">U</h2>
   </ul></td>
 </tr></table>
 
+<h2 id="V">V</h2>
+<table style="width: 100%" class="indextable genindextable"><tr>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.validation_step">validation_step() (graphnet.models.standard_model.StandardModel method)</a>
+</li>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction">VertexReconstruction (class in graphnet.models.task.reconstruction)</a>
+</li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher2DLoss">VonMisesFisher2DLoss (class in graphnet.training.loss_functions)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher3DLoss">VonMisesFisher3DLoss (class in graphnet.training.loss_functions)</a>
+</li>
+      <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss">VonMisesFisherLoss (class in graphnet.training.loss_functions)</a>
+</li>
+  </ul></td>
+</tr></table>
+
 <h2 id="W">W</h2>
 <table style="width: 100%" class="indextable genindextable"><tr>
   <td style="width: 33%; vertical-align: top;"><ul>
@@ -1156,6 +2216,18 @@ <h2 id="W">W</h2>
   </ul></td>
 </tr></table>
 
+<h2 id="Z">Z</h2>
+<table style="width: 100%" class="indextable genindextable"><tr>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction">ZenithReconstruction (class in graphnet.models.task.reconstruction)</a>
+</li>
+  </ul></td>
+  <td style="width: 33%; vertical-align: top;"><ul>
+      <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa">ZenithReconstructionWithKappa (class in graphnet.models.task.reconstruction)</a>
+</li>
+  </ul></td>
+</tr></table>
+
 
 
           </article>
@@ -1180,7 +2252,7 @@ <h2 id="W">W</h2>
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/index.html b/index.html
index 7e48c3ed6..fe3017fed 100644
--- a/index.html
+++ b/index.html
@@ -413,7 +413,7 @@ <h2 id="acknowledgements">Acknowledgements<a class="headerlink" href="#acknowled
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/install.html b/install.html
index fa6e7827c..a9b5b61a3 100644
--- a/install.html
+++ b/install.html
@@ -505,7 +505,7 @@ <h2 id="running-in-docker">Running in Docker<a class="headerlink" href="#running
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/objects.inv b/objects.inv
index 3ef123ab7aaf3155a34935402f71d0a18a79a679..e2a8e9d3320ab422b2aefc6b85dc595c1a5818d6 100644
GIT binary patch
delta 6145
zcmV+c82;zL8^SP<cYj^oa^p6ZzSmQ*YUVPn-NaQ}TbZkl?e2JEJ09CTsoEPBL_!i`
z{xtyFR-R^_WuI*EM~ak4;s7EkUvygl&-uOs;NX`4xS^=nW|#-X?&ma*%LE7i8K8pv
z$HqKZZh4ygO+Whd_u$i_eh}pu<0#{-U$++NM~M+f{Q&amuzxPZjJ6yUZ~5i=^YgdI
z>n8~uC5W-$M>{SAaCvoo{c!pF<F!OaoKvzYIUWuD=JNXT+b`Gu{a@3-|MvOo*V~Cq
ze|Y@;?)v%H>u*of;D3Ag`uO>3T;rb~zdiptnvW>7X&gjDPvadLdji+cz!Uj}1|QEU
z<bOiw6;qC>rGGtY$y(Z^w(OM%T{c<z9j}AL34Mv*aU`=trUmnD9#2EFCJ7E%xn7fZ
zNmbCi5D#|KkZmvzQ&fcOGK)CLv#Ar8V!FYhOd^|_ilJ8=)?MnTL|=scPkNv*s=pcB
z<tB<%9@(qm2_)oVI*tfMTO8&kFG{{I?b^vkVj&sf;eRH{S11Wd7UTEvNRk}Ipk_-8
z0@^GiD1G#{1|RmGCN>asxu$ulyq<}gEt0w>8-V|VnW8*N5EnH#1LZXsH_k)06>XmO
z#}RPZ&^W#+1t>aZ0&q+(?B!*ilxcQH7$e!{`aQx$)g<FO@FSGfPIt9W(y$-;V=qx0
zhY4EY1b^U!C(sB-kAz!h2-H2FK%}Y{xLTb+R3N?ub^j-Dd@nKGg&eJtiF1w9l5h7!
zQ>5uSV>lJ~Q#!S5bwX47A+3}LMS&>8p^^z>yp}zjX#&P%6D*?ap^cf51~z7uwD}vl
zHJm-#r*SVIblAx^1}!<&1dpocX%G9J5Ys5ybAMk812fwQ>pzag+^KsC@AB!Nw<73N
z!UIN}#O_hJ<Czl5mbIiS9G_AYdz4cOrGe7QpogZ3p5=6C;jwA>ol_L?oZ41H_kxzh
zHU!B<9e|@Zl)RXhHy77U*8Z#;BK2WgW+=6H-J`S0a)zuGj7=;C;t8(TL*~ooJ?|Vz
zUw>mHY7q8v<U?H-e?HQ^x%i@^e7w8clG9DNiST8)!g_7^W7u+qadDYb{BbP$RtT8*
z90om<3B&g}Vfa~2F1{paGp|h>oMB2LuiA5{>r#7;bZ;)c>L{Pu(`|{xgzaucNlbB+
zd!@cZ-6aFM)I}bkMEoStLxM8=s}gw>B7Z)4(rcf^u#Bf=o~hLi=9Mv>04`b64RCXD
z|9mInZ)Lnsh78zcF-S~SM42Rheaw-vOI92!i#Ef=`Q$;5F*Yj*Lw;#*DmtaRr6|&T
z^Yqy(y~hSvn8mDz`85b_O_$UjXx?1>^V{{)?>_16vcW=5F{geBZfm+E_(1dK;(t*_
z_$0W?1`9bW5|ZJx@T<<p+Aaz2(7w5NQjvEL*Eg3wS<q*Z#T2!g>6Ha-O_$^!Xx?0?
zQIAieyKJzLm#JTJ8+uNOZR&}ne!KHZ=n?xMZOKVOIKd991+@x1aeI}hp6hurHq=jK
zU#=6xInKfiGmhghCMq7$opsq6W`CK7NlXk%Q6dKS=QBY(&RCH%911&ln6FjFuz3wy
zQ<VC7q}KFt4Wp_4goK4@o^c!@n&f_IZ(8DMLY8~Gf<L?58k48t3*NC1WpP-@8K-ni
z%w{u%?+U--SBgHRuN7Amqn9w+MG1*QF%v2&gS>{anc`$agk-f_{VD2A$bWN@{S=4A
zj&ILmhbsI~m;PC;Q5&{!y23PkMWT^hCBE+HYk{bbt2`HZqWj_VaO12@XQw8Ss0k~4
zri7bC;JIn#j|oU&ub^SCp1WR+Lbsj{>Yl>eeB%q;p$f2N3?mMAm=5p$o69Kp&d|9m
zf^H!7vxgK|&OUAKX~b-vSbvG#Wcp3(Lj!)m;acd&u%OKYpiKkE?{J|NbU@Rev++I}
zSU0U^oXo#O*-no$<c4fD>diu&7JO&%zZpz(YB6T82i{oBU<|aK7}*qc0)v*3ZgAA(
z!rD)dY=Rt>L?Ng69L#!1F6zqk+|ae59xbPKYOf_~8H%>Ro$Z6vaDUE7O#DPdy?1GD
zPBqh36-rF9gydW{$cHN={0YAk23X}8n^q@&UUFJyGK<3~k7q0E35yn5NukwpFINrh
zwB7a*(SW8uXEjEB&UGII9Jg3)8OI3eOqTBNBul7h7xEOG?^+z3*+#rTguW5E&I3G(
z(oHzTBzkO;Z9<04a({40G0T%z>nIq208VW`!jpRPR6TfX*aq=1i)o|Uv@`->r*kq6
z7fHTLaW=F*QOgWu{9cA!XaTmYo;FSrSQ&b?Pvc&GmEn8*d~<PE8`E9VU|ZacxVgB=
z)|lcf!e_NYZ7I7<Fv>8U#t`uUCj<SHe%@cE=6RIMQzDsgXMZcyUep0~`MH|kUACAy
zv!aK{!Y~8ya-GvRMC1Mz;$#Mj<1)fwUevE~duPU;TqWaqdiDF=WcekqKWFhXN%0?I
zqLIH%Ni-RDnXM2PQzaa4uzkMd#E}+xh6is{IT8ibv|2PlJ4~IxwT3&#@@kh|i~Ny8
zHhcDW(mC@uRDZ$K2%qK4`l(Z;NajgWm4VNM=aVL@Ni=`bmMLZoZSahYV$>%O4<Wf+
z#;d9~K6R1jiA$}P@-B5+$!}<07U60KYV10w;2e7IgPgFLS=d~3Mark7OGy@1MKv`?
z_&y~I<MHGvsW>K1NtsTc62qaE8)L$%LE9XZqh?+#8h^3jZJ7sx)<&ilbcgomSY9>A
zhB~_*E;o{;Iq*}xJQ!^MM`8qatv+mKTME3Ki~_uvDh2zQxIrp64!=_JCbS-|8=Yz9
z1pbCLn`DOMX`Y!AXFK#aHqq3{D`+kuamnD-danbo*$thC79}Vb4;kI5{PGdPyU8+Z
z)(IJFR)0wu`MP-UV%TF@7u=<so!k#xThRv(#0B?0V1js;$+Z?xh8S5;hE%d<4FQrN
zEqEZVQ--(@gnyEZahGh-8ZMb*G)z+d=jQ(D*&&~reC|X(8Rx?IqWX@In`IAhCN}h(
zva+dX()oPz{JTqj9#MO8dSxw`9NHQ}TSWQq>wo<{B%^9b3zky`fI+iWNV1~T3%Gr^
zmv{o}`z^;omSKBeuf^;oq{XC*fH9S@6S5D!FN<2lyJ}sCHt2UfoL*#b$2^g>j1m%K
zq?!aaaT0&Kzn_EV5yb>0Q!#bjlLzs|Lp`WN9_gv~=gw}~O@9KftB3j-z?`y|Ic2JH
zNPm?omprt4NG3vHyOl81Y%T18@5{2h7ymo^+7YK0nX7}1#Ol_^j!G@x5@n@2XP2)V
zLx`zU&ptD?4~30czUq(FOVuB%V`{0p3USE92A~CLxhCA2yDo4*cXfe;-2CefEANyY
z8!u#BZm5>$)yIS`D4qy1M67*?VJv*P-G6v&yb*I(enBG^ej36mXj|M`QxC+SOb}Fc
zyQo|_)LX1lWL1(gd)qk@qpY8DJyGt-O37c4H~gb8LtC6t$VjzcavnNp#UGQ*tw|V@
zS?&$icbx$S9itGlEfUnCUBSgXDdmXDdj#f9C8vmFGo?>bsombQW$7ytb9bbiXMa>u
zj@XN}V~|4D?m!8tHlM6Xgh2D<4jmt!<u63uCv@40d^RZ?>V|sxPbFT8$uBAs9(GoG
z+b-Ve%{r=~=QF9bf_gSHZu1S{>}>v`=|8B~92oa;e8(6zuLs?UJRmA=j`l`g6Kl?4
zgB#Sg4y-{v(E0oqNlU(cJYu(|u78XCB-Rls#dJ3L)rIlxajycS+wceB+uu-8p!2v%
zP5%-7<=D6n;y8Q5=JlW}LaXzZ=kXz%0cZ9iuzB<Qk=3+$J?T2bbhDeeqv|3(v&RgY
zG{2X0nsa{7kY4(bzhQb77pd7lpsyU;_A#7hZ`a50o^M%%DMuvJ_hy~NhkqJ2d{9s7
zvhd^hQ`5S6z3L{z!ePYxk?^_utbXb<`w+me`CVtgsChl+A!p=xv^BZ6bkm;MQ!GZ!
z?<GA3&FdlZa)p__<CyZyo-tt1{9e&v%)B1)FL`Iy`^-MkMS5nh7&K{qzvwh)UeEYn
zoRQg=b-PH<>=}b5&F>kV=6{^uGv-^48)W~0zVZ<k;)d;-*n8mn2H=|q)PFZI&bSzK
zgI}lZ@!`P39%b3BayqN){uqBom*YuJETb2XGS_KGpF8|)r9NM2D<<gXKzC&LPC$3;
zNqtlA(F*Y^O2#l<-t%cchoPjbZJm~iP!g4i-svkhYen*mx@e^#*MGM(cpwDtyT5q2
z&C|Ttl4Lg){!5mH`k;@x9Pb2J`7O}?yp>P;^He=u6V3;3t87Zau;kfxo|p8C2oAh0
zI0gnH%GmJ7Snl|jk^Me;8pBT`GyH+cX3Iy+lqR6E_)SHyam>~GwZpQAkukW>MIop~
zl0*Eg60<m3Ehc}2KYv~|3u@5HHogt`rwM>5MsX-c8N;7&v|(1eX#2CXlIYLLT0*pT
z$qGd;41iY4y5KG!2#&D4zWbJB7}49uUjOeSa%NCpV@n3At~V3SHxx7W3!>wvVvlXU
zCMSBeh)tS^r94(w4eLVwJx*t!<Cx9IrA9L|ku~<3PtPbx^nXyT$mJEgSv)JB7<y4Y
z7UO=1Z)t(FAfWj%g7uyodu?q(4`SZRKV#!9U+`{dIPU^0=73?<w}3LWr2xs?wu0GZ
zOhxRq=-cvQcWb7hmK^Hf1rg3jRp0$^WhNUF^WEl~%e&tuCs@iVIXv|pjmkR79)~Gs
z+sVs@u`kt_ZGRq`IWq6=>3ZUfdz_x)UwL*X27v4fVOvapx|)^lVqUu0$%MovjMA4W
zy|AA(M1|-T-_Jyj-%-Rz<+dz|wdDZ4v01gs6SZcw9{lcg7x)mWoQLceec71MDSFZ0
z{F65$46iYGLF>V4V@N&?m%4dCjFHQPkCZ)nAt$u1&42YM06gYE;g>&iOtahk&EBYV
zQGqzVWkhsk>4XfR6}E9{hHpoKa{<1P2lzN%;{%m+Q^;?qK;SA_7u@A%x+&#HmL6oN
zKezi>+~D^|vXO}&el<H1>l@)}mfJzyUEaRDzwf_NG-@syT)CDjx9m=5`Lf81(7d!8
zjihm-kAKC3wIsxUjimapr!#&-%a3dMi}-z6(^|^9kJX26DevqUyeMQM-ndMQ*-3|Q
zZkmWxy{T(B6`*yRfPYPJ60_O(<>(M4N!>Wg(N>8Ehe4;(R~HFY-$4de-lm#E?QN-D
z9#F$7@G-L|WuryxTO^#29CXXtg#2olJvqOeB7evW2Aw;HwxWZF9VlMupldJDMeon+
z(Oryefcf}<V@o!&5nY$qp3NBb@kH?`Dbw>BqSiUgjOK%xQ|%j5z0ilTidX%`*{ty&
zn@nPe&md)FL*%urA*UtZeh@Q0fR_~L%(<aHE|{X?ysp78j?ZjBvVUcLJ}VBs%6?X3
z&VOH1&sRy+!1s7u3Z7rN8*tcN0EfU$HHQIkOHEzb=sNg4CiY`ox-Kt{&hI?z9*tV%
zbeNedUC)qq|1rwQCisj_PBKAN@WbSTtG6q)TzDOPtJJ6AP^Bx4YI#UXX0i#APUg5g
z)MAIzR#J0u5_Uj2C!s0Gi5U~ZbaAqBrGFVGPb8<JZcv8R*MLz`44=U`g|WVv*ty#~
z&~&IKhnmkO?ArfzT5t;Wvi+tU-VTm$2}U&8g^cG#A%(;og;a;C&S{o;?|V)Y{5#C4
zZ1L$i$G1fv8l29G7)LopJg1<S^?PYZ!>}fzQQd)nji>{v_$yjbauKR9tRZc&SbyNR
zV)%;RInJ1wE{+Q?JG>}X#!munGG;t%k`_7Tpu#kwkix7*MQ)p~7K>IfdO;hY=`Ah*
zi9C!nErX~X(RMltiochH;ySy=V@UGN1~`^BqL31<L~q4k(B7IBHVAYXQ4t*to$x&g
zKE|tZGZ8kXG^Z0`%W55N#b7frpMR&Zzi9PfnVf)R>wKadVwm6JS3H&DFitT_r!Rsx
zMC${T<EO%)hxy>~e?N*xTwuh%kc5xkHZ-ASRWJ-wSud|KpeoGdv@(AKue(DSQ^dER
zzE=rC+GQ<>3l=8i1=x!;z$GeL<s}aokAIHyDEK7QWQ2Mie3m~KO`Mis6MtXJKissW
zm+B50Bv&u{^lAx>PzT#Ma`M*;As=dO`RiB_U`1QZYKr<#Jb#sUfYD*pmgwHhI;scV
zOm(f$T5gZ4!#$vEpa7dfdpHs3$8y7T#4Zs1UP7^ru*Nn4J9hOqbV^5$a*I2(r;mLX
z^%27;wX9~)zpcRLk6Y3@Rezhb;B*~egtEIV=}I)G{)X&9x=h+X?J!Qv0V#^q1bL8u
zFQ+N?;X%+Pxu^qRf`*PJYed-C%w8lP(nHmX`!L>S!UkuUis_4o35Tkc35W4E6Slh*
zC9%Be%*Toj?aGjjM>bnBo|bv04?_D`)2UzC)A`V5k!XtKr;WUv<bO!FvgGK2)e13H
z_p+d=SF<-GZOoz_W?qibR;;+&u{LswIrSoMD^}#~SQ|Mj5|ZJx@GAOa<;sZTCpJUW
z4h0`0+KQC{?N}>$nR-z-)N0~J+%WC_)D)DleTXWkgFqg2#uA&8knWrf?P}CpM?fu5
zUtRSf*NU1~|AtcXU4P-lsPr_U`q{w;8g~N6n_WXi;Iv^-^?0l~G@U3&fo09L5E?0i
zZ9*Mzkbx^`z<NB$1Ln<u;?c-(C}Y|<z^dD$)FK|L#6TK{W??T{c?8cxNGS~{wmqzl
z9%QE|i}MspsiifK{yhkOxc5&UZ_blfo2vi-Ov&{<KClWGNq@eR8?yR&YY{Wmt;2m#
zqSvW5F91Mm;i`8sK|8vM9LlpHt!8ri^QED9xEaKW^*~vZA4D!XE~e~dG51y~9R0gX
zUG#|Z7^qr_wIJ;!tFE&O_W-9wp5dWos3XY=xC&c<mkHWo>VmHoG>!gkqbTw`aUhW5
zrZEw}I>BSi@qgeT?Le;bLme7~bjS7~sb1niIgGrizdA6?P|#d1dE=6SN%krx`|EHg
z)h5eq&_cx2!Yu2<kc+6VL+ixHqV;O44D{m0y9~!048b*`rUZ4sh8zA#GR7UywL*(S
zM_HL(&}Nh1kQ1J}@HN!htc?e~+`!Vcm#_?5eNam^T7PUd0JfBgfI=!+vot4vT5&zl
zT{w=8<H3jy)f!pVh_w<k7dmix$OAx7WF@PC)lVy~2RevD{jtHSKaBxECsKX81_-h(
zT@59haf`l7%5Bi~9p-`jwE%vB9E95O(ej`jw*Y=v0z}vGEsCHWw-~(19Yg}iE`7<B
z1#Ev6l7E_X9k`pi-MkN>zT7;JzZSh$QV#<(ujC0#V9|dU^DtQJ#^#_3EjsT6>P2eZ
z+#FP)h2?HlJ6!u3)}Rh9hHe?P!!~Xs4fO8?<a#s!@bPtOK|5|SaJ`ltf7tc<?&7k-
zc3kxowBr`tTLpTdVT%(TcR$$yTleq;RcJA|SAP}&+t@@K=-&(Q6(Df<!+S%5M#KVs
zD+HYI=q`z%8MAOXg$F0?IK3CN;}-DqL2%lmQ$;~DW&t{O0cUHT(+Esp;k08L&egTm
zJ2;43ST}P6_zre{2mA11cvCR|ySe>1FoDJJMpgi7bIWUB0*m1-SOC=KX05;k7DL;m
z0DqYKZB>E#sqlK1^;!#fj7AH29h%u;))3vtYRStwyujrroQf*0ms?dY*WSDJx4#Gf
z_wU6z$={-_ycE4(>{rnfd`i>?%KR16SAySyd>tV9hD7bheaq>K1PV%3g}%zP-{a?>
zgZB0;oXIwa1!l!vipF^=ocI$fa3n{Ts(+OT#I%=WaVf@ryHC|n`3DijF->tykjMkI
z>0MsUh=h{}i9mj%LIUzFVG89R!Qrx>+OEf>R^0|l4XM2<Nut!{x>SPSgRl4ZW(m-w
zi4pnWhoIQ*7>U@YoDwVyiW^IB4aR!Fl%fpD>7|}3)X|a?zZaMiwNPvv>>ny#OnUBw
zrMn-3t|jG08nj&sQkY?kl@InGaU$7K1|JO0MbQ&6P+4r`s$?t@-tvMfRy)^?{JGeR
z86Rv&=-q&FjYHVkmQhg^jtVS<$sL@<?o;(Zi{I&`Va?unYEgUG@vTT;OPgUtqitJJ
T$7!7>{ro``+2a2J@}KaTsY(JS

delta 3434
zcmV-w4VCi3Fu)s-cYj>la_cA-eb-k|HT~G}On0iLrn;ZfIXUOJ=3LUb)I6%Ngl#&q
zPy<ThznO2DFB@)>kt_r#vePe#fW6n+zy=}Mctc^eEfELR{+Aue>jHs)09fI_Z%n~r
z%Xh`!bkW5R;38=YGEy=QOU~T7t%w^Xha9>AWa+qWh#6hEP=CC|$D6B%$Ge+*37i#>
zG4PWgHw19;;pXP{`up9DM24JGys9~x4SjiibN%?+&A<M68u))-efsp}M5f=~egAs%
z@cHKP{xtZHx1a8=KFn+U!`<V<=h=Kjp`FG-H1uh_Lt~%7H8k*v{6d4z=M?gf5PC+G
zBkGxtI<lTQsedc`L8Pv27PsRKkVK&`&@0MhbI7_Nz9so-Xx6wuDXZ6O{3@v`N-9yX
zKMmOi@pK2PbX}Jj$D}-U<#vc}P%5j)PEEz&GfLZ6>a0W`h5h$>p)jkz8QjGt%T*is
zhv6GY@a5?^GMH^qN@`x!{LtI2lg-3}OU%<vK~}IxaetYk*Lg_LE2l8yg!-0I2SLxA
zWeMWTHW7|)DNSNpEH7_tR=3lRkxpS-lyHZ<)7b2+4Dp~zW^UqZQWU72NYB^I2i|j3
zeT^WmDVpTSecc#;KGI!Ye$-i0ynEP^x0|rc&~?2+dQth;u;l|n)it5$uVc|;C1CRR
zFzB`}7=QXkFhdWt=G%`2+?>~@4Jr}E*`(3a*NquHN4m?)PdaPL=oz-eW5Ra7qBy50
zBa>=BpdQlzL+YXpU?CoHc3Z#_eO4lOLS#ytIPLQoR`Oj<$^x#CaLp+VGYAmVG{XSP
z%Wn@~MgA9=KczuN?D7~?K`Q#BnD&l2QjTegV}E5aXPCSxZQwA*X9Z?(THS5MsCsu4
zMV*)TSCi^{Y=DRP&l<656GB%rrnWuJ<>l{>H}~JC)OW}R4>?7gPOETNGp52l&E@5t
z%$QQ)Asam8tSWGccGa}ue5@T)@d54S<-N-MdV90Hp3(#^i#(=qx!{=81YOOT>i0C4
zmw#&2Go{i)Hu%Wv-L&djdQp|N^+Z)4zfP*?5&I}@&2fP_MiILOtx9<14k}sutv89$
zQhy`+VqHMaQJI#Aag?VyR{4<bz56g>mQ65;iA5<a#Nz(_Oc2E-s|Z7>u!E;$tr~`%
z*Pu0pyJ;S&HK(|S(NzCZ%+ehxIm#d{$bYoD+m=i;p(K-CVLH3qQzyIh3GG=5%RH^*
z7pHVg#Li|4-&fPRUnx%MeXY2n2tK9RJ}YpRiZ7v>GP4;LUv0~jC>t^;hu!KAF>X>$
zMDtUWR(rmE7duqvr@Hsw)f#ougLW%K(`P6q$yG7c9j96##^Z_*(I)ym{65^MtbccB
zr^cBW2`ltY6*rr}=cbhpCm<Ey)8dnhd*(JH<(bGDDf2&#IkO(seGBj6R_8_>Z3oyg
zhY|gw>G5O3a~TzWq<k)mLQkG2*~9)6qXc=JQww>#&<N>&QNBN-Ih}KH|N9q_8R7O`
zGv3U<faPAl!ex+zT2Pj$`0(d@Pk+~*!L+Escm{jI$Ke@_2_1qno1&tkXr1Y?tk&(`
zSK!Ph$QOQ=a*E!AS+91*SUEj+=(dYPD{7uPXo*n<vu$9s$tZO=XC}U>L<W~j_V?ga
zAF*mciSHL7c`h4d;R*_WLa&$wY@p0dt1EwAa$1+NiPMbaXY1<|c5Ub&4SyY$2f1d=
zr0w>vh(<K0b2iJr2cq9Wf$J8pE%O*5oyqh1dy^$p%nM`}-0oUDxUr3R1u=cWatmcE
zmO23w;4tABlXzp}a+5M-c3M-47%85;tKa~H;M%$o?p4WsQ*dY57UE$Z(`L14VFZGm
zF3LDm6=c6d<=CfxE8}i>d4E3bMv$fBx?<hUxv<)UVWh&a7GlhDQqzwjInuSqz=4P|
zH~lFWZxxm4?w_|Y{5&(me-Pem`HYQX11gJOR0f;JTv7Km-lTOUyj&aXny5lhg(rvj
zStI6gv{p<$q(6PIJ8IO*F1`===Mw;nm1v%dAw7Pq=)>d`;&QY(8-EFROSTf)ZBp+{
z;82KWSp4<}LA3lrUYx^Vlq!hhvBYaacPC^B8qf36$Wi7V-QipT3)z4*to0p9H)6%b
zPzYQj8-u&JqMKU22ybP|-P$geIQWIU3YJ5fr_kiH(JiT&SYI&Tro+_3{XLcUd)0<p
z%-V_uSH3y8@Wj(uzJI7lmF8#->w?cllB0}JIDSuNNIJxi)i&Zj2-t{5po;%xo|^kl
zqo4K;77zRuvJ|~?RQk@or@>>E+84xT2e&SU&4Z116`?$+GmU6SXSSm7aKKu@><MlH
zP4938kjN#*X&J=m$hN1Wpy<yUi%+bkx#nXi$YwLT5;_tM>3`uy^o#f(^bYG_2Z1g!
zDzY<Jgv^Q);B&O9Hz&eIloEO(Y<0M_6>ru8>FM2dfA38!Cnvz=nw+SG7~)^hGdh*y
zFfP7(cBk)x7f9RBqWNoK&{I-oXl{x-R3Rvq`30YSZm^-{P-srXPh|a~T%}Z{Sqh7S
zyr4XA`Fz9(SAV4Dfbsm7oMhlaD9Q{~4zA>5Vk30~`#*F8u5Opf@K4w{^}4?b0or78
zv<v8jXn*Dg$fsoFmnX&{!cFP@RK+Oj@j)f%WI50a0`m}PEqB1xX$L49=|EfP04EY|
ztT@a_>?Ud}2-^<c0w&?d?u5~>SgxLOhdcGx&zwO8QGZ(JD?mScf$ekh;HnL(+O)y-
z>S3hvHB{4;n5L#1asU}J=}?-C1r3h#_kN!N9MVbXefVI4{vCQ?Wh9Z;h$Bb`Zy<x?
z{3P-^Z-fbhH-W)%eiDCsHo^+m-CHnNluWs}8ext5;wx|zCb8E@Bdi&|f(ix-l9<b#
zF%}$M3V#K?%p~fXWd!xW{Y=o$O`@(mMvxEOl?46VB>J*pgaMA5ieRWDiMjw7VZ^}A
zLD0`lA}-~{Q1|chfo@(B7Lp#!*gXUQZ8C|!?=^w5e_IS#8A;S-qY32JEhV6zlbCq_
z35<t(0zf?`&JgQy@WA!v_-%n826v=47%UpH0)JOXj$99+kuumOmAC@B#=(LXtYgvS
zA?VIXP%u=I1c#=@koQJ!KQ|ePWf`Z5m=?xD+yXg<W4XXlcH)no3Z=ASt$=<`CI$Y4
z@{LCh0cUaIh!=A6c0-E5A?<<^Zymj@%|bw1;pXjRl74iPIo4)lT6T5Z?P4i5rK#EB
zkAGtADQogxX4eLhC)USK7G11y^z)Fqcq7VVOtr)5K{`k_ug)gz0Nz!kL}S}fN0JS2
zllA~F3b;o!245>^8vX2}s0b+{2&A}~OvKY5c#b$4U8MEo8b8#jQOJ1gDJ0F|SSW{)
zH{(|)#u+L~<W@4OSun{##Wa7L9;DiorGF1vh?r4WmM#priTXUYPkb!e?zYN=6F1+N
zDBqwETqA0FP$zu2=^wacJOW)SbU1XDm+1*^HU&yK<|KyCQtPr>1y1h3-u3seOnZG$
zOIb5E8v$F$N`yixS+g}KAHBE}&|^5%KN%(?e^^$YiH8yO)vTtna=pYXhAy~2N`H|G
zMOLyJSUq}iC!hy$sDD0K_0L)W3?jABVL^~x>1HWu=M9o@$ew)AQpS9$iYF(`LY-pb
zHO^p&2M(y13AmHzrjxs(=E`CyU01~{`8#=Ty5N=yXfF9N;PpKvbol7%xtqER;6iAx
z0s#4xI0LxH0ql@cU=oQta(NuA6Mx?fnp83v)9OU(#5sc|H4L)IU{S1K_Qd>ft*G(9
zCh0)d``><8E!V7RS4Ru|gCJq7{T=%d;Euc@=;S$w#vQLK=5#$*EV&nUAXFK2^OE6U
z0Vj0W=EOkkvLCP$M+%yh3s_&E2!Pds?NU=bk3%lhea={TOG#5%86H0UuYacM%5E^B
ztz}nE=XfZagVWercJ(^v`C)ZcGi@}jt!2H$!ziT}jSUu;Lc0&f+{2_%X$D+$&3cO6
zY1UIeE2P@b*NdnIhjcO?tQ&xBhxG!JNX7$115oWyQeYCvcxY1qsvQIcOu~h_gM1re
zsgg5(DIXP{V_UDYfyZc;$bWEt6S-xH9&)whb(>C5eSvmji0kcE^_J_#ZvE{C@W1~i
z>w>&wTS)nc*k|!3xF}R$B6&vi8KW0K)&R;A5*6qALg<qO0<{`KpJd&?p@&~UH#`fK
z`WM2?j=L1i$xb-&1*=ddSC(p)0OGS3mw7GLe*24Nsr;kJ;)w20j(?$O0~JgzuNFb!
zBtinn8(k8RZ!uFSzk>d@p9+IxQmdT<rH0gA)wodVa$l-Iui(?SZ)OY7)QMqP@FS?U
zdxkT1K`2JTpxoMevl!b2(+-wUeqQQNg*IDqqSp#htTu{`gZ--V#mAknbpIn5+ETXa
zpzTtS!VF`qEWqU&QDY>dBXN1uz`ez}=z1&`D#=FfO3L4#=y*YstDS2pACrTa@j**M
zM+3?|4q<0kMrAcPDv=122RN(!MN^=~zv``F&2E)?)Nb|f;dQhbf;8K96?LB0x!2{l
MqRW#116ZEXq$@eMJpcdz

diff --git a/py-modindex.html b/py-modindex.html
index 058464641..c3e8451b2 100644
--- a/py-modindex.html
+++ b/py-modindex.html
@@ -360,6 +360,46 @@ <h1>Python Module Index</h1>
        <td>&#160;&#160;&#160;
        <a href="api/graphnet.data.dataconverter.html#module-graphnet.data.dataconverter"><code class="xref">graphnet.data.dataconverter</code></a></td><td>
        <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.dataloader.html#module-graphnet.data.dataloader"><code class="xref">graphnet.data.dataloader</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.dataset.html#module-graphnet.data.dataset"><code class="xref">graphnet.data.dataset</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.dataset.dataset.html#module-graphnet.data.dataset.dataset"><code class="xref">graphnet.data.dataset.dataset</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.dataset.parquet.html#module-graphnet.data.dataset.parquet"><code class="xref">graphnet.data.dataset.parquet</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#module-graphnet.data.dataset.parquet.parquet_dataset"><code class="xref">graphnet.data.dataset.parquet.parquet_dataset</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.dataset.sqlite.html#module-graphnet.data.dataset.sqlite"><code class="xref">graphnet.data.dataset.sqlite</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#module-graphnet.data.dataset.sqlite.sqlite_dataset"><code class="xref">graphnet.data.dataset.sqlite.sqlite_dataset</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"><code class="xref">graphnet.data.dataset.sqlite.sqlite_dataset_perturbed</code></a></td><td>
+       <em></em></td></tr>
      <tr class="cg-1">
        <td></td>
        <td>&#160;&#160;&#160;
@@ -455,6 +495,11 @@ <h1>Python Module Index</h1>
        <td>&#160;&#160;&#160;
        <a href="api/graphnet.data.parquet.parquet_dataconverter.html#module-graphnet.data.parquet.parquet_dataconverter"><code class="xref">graphnet.data.parquet.parquet_dataconverter</code></a></td><td>
        <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.data.pipeline.html#module-graphnet.data.pipeline"><code class="xref">graphnet.data.pipeline</code></a></td><td>
+       <em></em></td></tr>
      <tr class="cg-1">
        <td></td>
        <td>&#160;&#160;&#160;
@@ -495,6 +540,156 @@ <h1>Python Module Index</h1>
        <td>&#160;&#160;&#160;
        <a href="api/graphnet.deployment.html#module-graphnet.deployment"><code class="xref">graphnet.deployment</code></a></td><td>
        <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.deployment.i3modules.graphnet_module.html#module-graphnet.deployment.i3modules.graphnet_module"><code class="xref">graphnet.deployment.i3modules.graphnet_module</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.html#module-graphnet.models"><code class="xref">graphnet.models</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.coarsening.html#module-graphnet.models.coarsening"><code class="xref">graphnet.models.coarsening</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.components.html#module-graphnet.models.components"><code class="xref">graphnet.models.components</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.components.layers.html#module-graphnet.models.components.layers"><code class="xref">graphnet.models.components.layers</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.components.pool.html#module-graphnet.models.components.pool"><code class="xref">graphnet.models.components.pool</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.detector.html#module-graphnet.models.detector"><code class="xref">graphnet.models.detector</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.detector.detector.html#module-graphnet.models.detector.detector"><code class="xref">graphnet.models.detector.detector</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.detector.icecube.html#module-graphnet.models.detector.icecube"><code class="xref">graphnet.models.detector.icecube</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.detector.prometheus.html#module-graphnet.models.detector.prometheus"><code class="xref">graphnet.models.detector.prometheus</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.gnn.html#module-graphnet.models.gnn"><code class="xref">graphnet.models.gnn</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.gnn.convnet.html#module-graphnet.models.gnn.convnet"><code class="xref">graphnet.models.gnn.convnet</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.gnn.dynedge.html#module-graphnet.models.gnn.dynedge"><code class="xref">graphnet.models.gnn.dynedge</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.gnn.dynedge_jinst.html#module-graphnet.models.gnn.dynedge_jinst"><code class="xref">graphnet.models.gnn.dynedge_jinst</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#module-graphnet.models.gnn.dynedge_kaggle_tito"><code class="xref">graphnet.models.gnn.dynedge_kaggle_tito</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.gnn.gnn.html#module-graphnet.models.gnn.gnn"><code class="xref">graphnet.models.gnn.gnn</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.graphs.html#module-graphnet.models.graphs"><code class="xref">graphnet.models.graphs</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.graphs.edges.html#module-graphnet.models.graphs.edges"><code class="xref">graphnet.models.graphs.edges</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.graphs.edges.edges.html#module-graphnet.models.graphs.edges.edges"><code class="xref">graphnet.models.graphs.edges.edges</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.graphs.graph_definition.html#module-graphnet.models.graphs.graph_definition"><code class="xref">graphnet.models.graphs.graph_definition</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.graphs.graphs.html#module-graphnet.models.graphs.graphs"><code class="xref">graphnet.models.graphs.graphs</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.graphs.nodes.html#module-graphnet.models.graphs.nodes"><code class="xref">graphnet.models.graphs.nodes</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.graphs.nodes.nodes.html#module-graphnet.models.graphs.nodes.nodes"><code class="xref">graphnet.models.graphs.nodes.nodes</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.model.html#module-graphnet.models.model"><code class="xref">graphnet.models.model</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.standard_model.html#module-graphnet.models.standard_model"><code class="xref">graphnet.models.standard_model</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.task.html#module-graphnet.models.task"><code class="xref">graphnet.models.task</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.task.classification.html#module-graphnet.models.task.classification"><code class="xref">graphnet.models.task.classification</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.task.reconstruction.html#module-graphnet.models.task.reconstruction"><code class="xref">graphnet.models.task.reconstruction</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.task.task.html#module-graphnet.models.task.task"><code class="xref">graphnet.models.task.task</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.models.utils.html#module-graphnet.models.utils"><code class="xref">graphnet.models.utils</code></a></td><td>
+       <em></em></td></tr>
      <tr class="cg-1">
        <td></td>
        <td>&#160;&#160;&#160;
@@ -515,6 +710,26 @@ <h1>Python Module Index</h1>
        <td>&#160;&#160;&#160;
        <a href="api/graphnet.training.html#module-graphnet.training"><code class="xref">graphnet.training</code></a></td><td>
        <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.training.callbacks.html#module-graphnet.training.callbacks"><code class="xref">graphnet.training.callbacks</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.training.labels.html#module-graphnet.training.labels"><code class="xref">graphnet.training.labels</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.training.loss_functions.html#module-graphnet.training.loss_functions"><code class="xref">graphnet.training.loss_functions</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.training.utils.html#module-graphnet.training.utils"><code class="xref">graphnet.training.utils</code></a></td><td>
+       <em></em></td></tr>
      <tr class="cg-1">
        <td></td>
        <td>&#160;&#160;&#160;
@@ -530,6 +745,41 @@ <h1>Python Module Index</h1>
        <td>&#160;&#160;&#160;
        <a href="api/graphnet.utilities.argparse.html#module-graphnet.utilities.argparse"><code class="xref">graphnet.utilities.argparse</code></a></td><td>
        <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.utilities.config.html#module-graphnet.utilities.config"><code class="xref">graphnet.utilities.config</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.utilities.config.base_config.html#module-graphnet.utilities.config.base_config"><code class="xref">graphnet.utilities.config.base_config</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.utilities.config.configurable.html#module-graphnet.utilities.config.configurable"><code class="xref">graphnet.utilities.config.configurable</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.utilities.config.dataset_config.html#module-graphnet.utilities.config.dataset_config"><code class="xref">graphnet.utilities.config.dataset_config</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.utilities.config.model_config.html#module-graphnet.utilities.config.model_config"><code class="xref">graphnet.utilities.config.model_config</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.utilities.config.parsing.html#module-graphnet.utilities.config.parsing"><code class="xref">graphnet.utilities.config.parsing</code></a></td><td>
+       <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.utilities.config.training_config.html#module-graphnet.utilities.config.training_config"><code class="xref">graphnet.utilities.config.training_config</code></a></td><td>
+       <em></em></td></tr>
      <tr class="cg-1">
        <td></td>
        <td>&#160;&#160;&#160;
@@ -550,6 +800,11 @@ <h1>Python Module Index</h1>
        <td>&#160;&#160;&#160;
        <a href="api/graphnet.utilities.logging.html#module-graphnet.utilities.logging"><code class="xref">graphnet.utilities.logging</code></a></td><td>
        <em></em></td></tr>
+     <tr class="cg-1">
+       <td></td>
+       <td>&#160;&#160;&#160;
+       <a href="api/graphnet.utilities.maths.html#module-graphnet.utilities.maths"><code class="xref">graphnet.utilities.maths</code></a></td><td>
+       <em></em></td></tr>
    </table>
 
 
@@ -575,7 +830,7 @@ <h1>Python Module Index</h1>
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/search.html b/search.html
index 32901805c..4199f1026 100644
--- a/search.html
+++ b/search.html
@@ -361,7 +361,7 @@ <h1 id="search-documentation">Search</h1>
               
           </div>
             Created using
-            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5.
+            <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6.
              and
             <a href="https://github.com/bashtage/sphinx-material/">Material for
               Sphinx</a>
diff --git a/searchindex.js b/searchindex.js
index 4df38d653..200c93687 100644
--- a/searchindex.js
+++ b/searchindex.js
@@ -1 +1 @@
-Search.setIndex({"docnames": ["about", "api/graphnet", "api/graphnet.constants", "api/graphnet.data", "api/graphnet.data.constants", "api/graphnet.data.dataconverter", "api/graphnet.data.dataloader", "api/graphnet.data.dataset", "api/graphnet.data.dataset.dataset", "api/graphnet.data.dataset.parquet", "api/graphnet.data.dataset.parquet.parquet_dataset", "api/graphnet.data.dataset.sqlite", "api/graphnet.data.dataset.sqlite.sqlite_dataset", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed", "api/graphnet.data.extractors", "api/graphnet.data.extractors.i3extractor", "api/graphnet.data.extractors.i3featureextractor", "api/graphnet.data.extractors.i3genericextractor", "api/graphnet.data.extractors.i3hybridrecoextractor", "api/graphnet.data.extractors.i3ntmuonlabelsextractor", "api/graphnet.data.extractors.i3particleextractor", "api/graphnet.data.extractors.i3pisaextractor", "api/graphnet.data.extractors.i3quesoextractor", "api/graphnet.data.extractors.i3retroextractor", "api/graphnet.data.extractors.i3splinempeextractor", "api/graphnet.data.extractors.i3truthextractor", "api/graphnet.data.extractors.i3tumextractor", "api/graphnet.data.extractors.utilities", "api/graphnet.data.extractors.utilities.collections", "api/graphnet.data.extractors.utilities.frames", "api/graphnet.data.extractors.utilities.types", "api/graphnet.data.parquet", "api/graphnet.data.parquet.parquet_dataconverter", "api/graphnet.data.pipeline", "api/graphnet.data.sqlite", "api/graphnet.data.sqlite.sqlite_dataconverter", "api/graphnet.data.sqlite.sqlite_utilities", "api/graphnet.data.utilities", "api/graphnet.data.utilities.parquet_to_sqlite", "api/graphnet.data.utilities.random", "api/graphnet.data.utilities.string_selection_resolver", "api/graphnet.deployment", "api/graphnet.deployment.i3modules", "api/graphnet.deployment.i3modules.deployer", "api/graphnet.deployment.i3modules.graphnet_module", "api/graphnet.models", "api/graphnet.models.coarsening", "api/graphnet.models.components", "api/graphnet.models.components.layers", "api/graphnet.models.components.pool", "api/graphnet.models.detector", "api/graphnet.models.detector.detector", "api/graphnet.models.detector.icecube", "api/graphnet.models.detector.prometheus", "api/graphnet.models.gnn", "api/graphnet.models.gnn.convnet", "api/graphnet.models.gnn.dynedge", "api/graphnet.models.gnn.dynedge_jinst", "api/graphnet.models.gnn.dynedge_kaggle_tito", "api/graphnet.models.gnn.gnn", "api/graphnet.models.graphs", "api/graphnet.models.graphs.edges", "api/graphnet.models.graphs.edges.edges", "api/graphnet.models.graphs.graph_definition", "api/graphnet.models.graphs.graphs", "api/graphnet.models.graphs.nodes", "api/graphnet.models.graphs.nodes.nodes", "api/graphnet.models.model", "api/graphnet.models.standard_model", "api/graphnet.models.task", "api/graphnet.models.task.classification", "api/graphnet.models.task.reconstruction", "api/graphnet.models.task.task", "api/graphnet.models.utils", "api/graphnet.pisa", "api/graphnet.pisa.fitting", "api/graphnet.pisa.plotting", "api/graphnet.training", "api/graphnet.training.callbacks", "api/graphnet.training.labels", "api/graphnet.training.loss_functions", "api/graphnet.training.utils", "api/graphnet.training.weight_fitting", "api/graphnet.utilities", "api/graphnet.utilities.argparse", "api/graphnet.utilities.config", "api/graphnet.utilities.config.base_config", "api/graphnet.utilities.config.configurable", "api/graphnet.utilities.config.dataset_config", "api/graphnet.utilities.config.model_config", "api/graphnet.utilities.config.parsing", "api/graphnet.utilities.config.training_config", "api/graphnet.utilities.decorators", "api/graphnet.utilities.filesys", "api/graphnet.utilities.imports", "api/graphnet.utilities.logging", "api/graphnet.utilities.maths", "api/modules", "contribute", "index", "install"], "filenames": ["about.md", "api/graphnet.rst", "api/graphnet.constants.rst", "api/graphnet.data.rst", "api/graphnet.data.constants.rst", "api/graphnet.data.dataconverter.rst", "api/graphnet.data.dataloader.rst", "api/graphnet.data.dataset.rst", "api/graphnet.data.dataset.dataset.rst", "api/graphnet.data.dataset.parquet.rst", "api/graphnet.data.dataset.parquet.parquet_dataset.rst", "api/graphnet.data.dataset.sqlite.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.rst", "api/graphnet.data.extractors.rst", "api/graphnet.data.extractors.i3extractor.rst", "api/graphnet.data.extractors.i3featureextractor.rst", "api/graphnet.data.extractors.i3genericextractor.rst", "api/graphnet.data.extractors.i3hybridrecoextractor.rst", "api/graphnet.data.extractors.i3ntmuonlabelsextractor.rst", "api/graphnet.data.extractors.i3particleextractor.rst", "api/graphnet.data.extractors.i3pisaextractor.rst", "api/graphnet.data.extractors.i3quesoextractor.rst", "api/graphnet.data.extractors.i3retroextractor.rst", "api/graphnet.data.extractors.i3splinempeextractor.rst", "api/graphnet.data.extractors.i3truthextractor.rst", "api/graphnet.data.extractors.i3tumextractor.rst", "api/graphnet.data.extractors.utilities.rst", "api/graphnet.data.extractors.utilities.collections.rst", "api/graphnet.data.extractors.utilities.frames.rst", "api/graphnet.data.extractors.utilities.types.rst", "api/graphnet.data.parquet.rst", "api/graphnet.data.parquet.parquet_dataconverter.rst", "api/graphnet.data.pipeline.rst", "api/graphnet.data.sqlite.rst", "api/graphnet.data.sqlite.sqlite_dataconverter.rst", "api/graphnet.data.sqlite.sqlite_utilities.rst", "api/graphnet.data.utilities.rst", "api/graphnet.data.utilities.parquet_to_sqlite.rst", "api/graphnet.data.utilities.random.rst", "api/graphnet.data.utilities.string_selection_resolver.rst", "api/graphnet.deployment.rst", "api/graphnet.deployment.i3modules.rst", "api/graphnet.deployment.i3modules.deployer.rst", "api/graphnet.deployment.i3modules.graphnet_module.rst", "api/graphnet.models.rst", "api/graphnet.models.coarsening.rst", "api/graphnet.models.components.rst", "api/graphnet.models.components.layers.rst", "api/graphnet.models.components.pool.rst", "api/graphnet.models.detector.rst", "api/graphnet.models.detector.detector.rst", "api/graphnet.models.detector.icecube.rst", "api/graphnet.models.detector.prometheus.rst", "api/graphnet.models.gnn.rst", "api/graphnet.models.gnn.convnet.rst", "api/graphnet.models.gnn.dynedge.rst", "api/graphnet.models.gnn.dynedge_jinst.rst", "api/graphnet.models.gnn.dynedge_kaggle_tito.rst", "api/graphnet.models.gnn.gnn.rst", "api/graphnet.models.graphs.rst", "api/graphnet.models.graphs.edges.rst", "api/graphnet.models.graphs.edges.edges.rst", "api/graphnet.models.graphs.graph_definition.rst", "api/graphnet.models.graphs.graphs.rst", "api/graphnet.models.graphs.nodes.rst", "api/graphnet.models.graphs.nodes.nodes.rst", "api/graphnet.models.model.rst", "api/graphnet.models.standard_model.rst", "api/graphnet.models.task.rst", "api/graphnet.models.task.classification.rst", "api/graphnet.models.task.reconstruction.rst", "api/graphnet.models.task.task.rst", "api/graphnet.models.utils.rst", "api/graphnet.pisa.rst", "api/graphnet.pisa.fitting.rst", "api/graphnet.pisa.plotting.rst", "api/graphnet.training.rst", "api/graphnet.training.callbacks.rst", "api/graphnet.training.labels.rst", "api/graphnet.training.loss_functions.rst", "api/graphnet.training.utils.rst", "api/graphnet.training.weight_fitting.rst", "api/graphnet.utilities.rst", "api/graphnet.utilities.argparse.rst", "api/graphnet.utilities.config.rst", "api/graphnet.utilities.config.base_config.rst", "api/graphnet.utilities.config.configurable.rst", "api/graphnet.utilities.config.dataset_config.rst", "api/graphnet.utilities.config.model_config.rst", "api/graphnet.utilities.config.parsing.rst", "api/graphnet.utilities.config.training_config.rst", "api/graphnet.utilities.decorators.rst", "api/graphnet.utilities.filesys.rst", "api/graphnet.utilities.imports.rst", "api/graphnet.utilities.logging.rst", "api/graphnet.utilities.maths.rst", "api/modules.rst", "contribute.md", "index.rst", "install.md"], "titles": ["About", "API", "constants", "data", "constants", "dataconverter", "dataloader", "dataset", "dataset", "parquet", "parquet_dataset", "sqlite", "sqlite_dataset", "sqlite_dataset_perturbed", "extractors", "i3extractor", "i3featureextractor", "i3genericextractor", "i3hybridrecoextractor", "i3ntmuonlabelsextractor", "i3particleextractor", "i3pisaextractor", "i3quesoextractor", "i3retroextractor", "i3splinempeextractor", "i3truthextractor", "i3tumextractor", "utilities", "collections", "frames", "types", "parquet", "parquet_dataconverter", "pipeline", "sqlite", "sqlite_dataconverter", "sqlite_utilities", "utilities", "parquet_to_sqlite", "random", "string_selection_resolver", "deployment", "i3modules", "deployer", "graphnet_module", "models", "coarsening", "components", "layers", "pool", "detector", "detector", "icecube", "prometheus", "gnn", "convnet", "dynedge", "dynedge_jinst", "dynedge_kaggle_tito", "gnn", "graphs", "edges", "edges", "graph_definition", "graphs", "nodes", "nodes", "model", "standard_model", "task", "classification", "reconstruction", "task", "utils", "pisa", "fitting", "plotting", "training", "callbacks", "labels", "loss_functions", "utils", "weight_fitting", "utilities", "argparse", "config", "base_config", "configurable", "dataset_config", "model_config", "parsing", "training_config", "decorators", "filesys", "imports", "logging", "maths", "src", "Contribute", "About", "Install"], "terms": {"graphnet": [0, 1, 2, 3, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35, 36, 37, 38, 39, 40, 41, 75, 76, 77, 82, 83, 84, 93, 94, 95, 98, 99, 100], "i": [0, 1, 15, 17, 28, 29, 30, 35, 36, 39, 40, 76, 82, 84, 93, 94, 95, 98, 99, 100], "an": [0, 5, 30, 32, 35, 40, 93, 95, 98, 99, 100], "open": [0, 98, 99], "sourc": [0, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95, 98, 99], "python": [0, 1, 5, 14, 15, 17, 28, 30, 98, 99, 100], "framework": [0, 99], "aim": [0, 1, 98, 99], "provid": [0, 1, 98, 99, 100], "high": [0, 99], "qualiti": [0, 99], "user": [0, 99, 100], "friendli": [0, 99], "end": [0, 1, 5, 32, 35, 99], "function": [0, 5, 30, 36, 39, 75, 76, 83, 93, 94, 99], "perform": [0, 99], "reconstruct": [0, 1, 16, 18, 19, 23, 24, 26, 41, 45, 69, 99], "task": [0, 1, 45, 98, 99], "neutrino": [0, 1, 75, 99], "telescop": [0, 1, 99], "us": [0, 1, 2, 4, 5, 15, 20, 25, 27, 28, 32, 35, 36, 37, 38, 40, 41, 75, 82, 83, 84, 94, 95, 98, 99, 100], "graph": [0, 1, 45, 98, 99], "neural": [0, 1, 99], "network": [0, 1, 99], "gnn": [0, 1, 45, 99, 100], "make": [0, 5, 82, 98, 99, 100], "fast": [0, 99, 100], "easi": [0, 99], "train": [0, 1, 40, 41, 82, 84, 97, 99, 100], "complex": [0, 99], "model": [0, 1, 41, 76, 77, 84, 97, 99, 100], "can": [0, 1, 15, 17, 20, 38, 75, 76, 82, 84, 98, 99, 100], "event": [0, 1, 22, 36, 38, 40, 75, 82, 99], "state": [0, 99], "art": [0, 99], "arbitrari": [0, 99], "detector": [0, 1, 25, 45, 99], "configur": [0, 1, 75, 83, 85, 95, 99], "infer": [0, 1, 41, 99, 100], "time": [0, 4, 36, 95, 99, 100], "ar": [0, 1, 4, 5, 17, 30, 32, 35, 38, 40, 75, 82, 98, 99, 100], "order": [0, 28, 99], "magnitud": [0, 99], "faster": [0, 99], "than": [0, 95, 99], "tradit": [0, 99], "techniqu": [0, 99], "common": [0, 1, 92, 94, 99], "ml": [0, 1, 99], "develop": [0, 1, 98, 99, 100], "physicist": [0, 1, 99], "wish": [0, 98, 99], "tool": [0, 1, 99], "research": [0, 99], "By": [0, 38, 99], "unit": [0, 5, 94, 98, 99], "both": [0, 17, 76, 99], "group": [0, 5, 32, 35, 99], "increas": [0, 99], "longev": [0, 99], "usabl": [0, 99], "individu": [0, 5, 99], "code": [0, 25, 36, 99], "contribut": [0, 99, 100], "from": [0, 1, 14, 15, 17, 19, 20, 22, 28, 29, 30, 35, 38, 76, 95, 98, 99, 100], "build": [0, 1, 99], "gener": [0, 5, 17, 99], "reusabl": [0, 99], "softwar": [0, 99], "packag": [0, 1, 39, 93, 94, 98, 99, 100], "base": [0, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 75, 82, 84, 94, 95, 99], "engin": [0, 99], "best": [0, 98, 99], "practic": [0, 98, 99], "lower": [0, 76, 99], "technic": [0, 99], "threshold": [0, 99], "most": [0, 1, 40, 99, 100], "scientif": [0, 1, 99], "problem": [0, 98, 99], "The": [0, 5, 28, 30, 35, 36, 75, 76, 99], "improv": [0, 1, 84, 99], "classif": [0, 1, 45, 69, 99], "yield": [0, 75, 99], "veri": [0, 40, 99], "accur": [0, 99], "e": [0, 1, 5, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 28, 30, 32, 35, 36, 40, 82, 95, 98, 99, 100], "g": [0, 1, 5, 25, 28, 30, 32, 35, 36, 40, 82, 95, 98, 99, 100], "low": [0, 99], "energi": [0, 4, 82, 99], "observ": [0, 99], "icecub": [0, 1, 16, 29, 30, 45, 50, 94, 99, 100], "here": [0, 98, 99, 100], "implement": [0, 1, 5, 15, 31, 32, 34, 35, 98, 99], "wa": [0, 99], "appli": [0, 15, 99], "oscil": [0, 74, 99], "lead": [0, 99], "signific": [0, 99], "angular": [0, 99], "rang": [0, 99], "relev": [0, 1, 30, 39, 93, 98, 99], "studi": [0, 99], "furthermor": [0, 99], "shown": [0, 99], "could": [0, 98, 99], "muon": [0, 19, 99], "v": [0, 99], "therebi": [0, 1, 99], "effici": [0, 99], "puriti": [0, 99], "sampl": [0, 40, 99], "analysi": [0, 99, 100], "similarli": [0, 30, 99], "ha": [0, 5, 30, 32, 35, 36, 93, 99, 100], "great": [0, 99], "point": [0, 24, 99], "analys": [0, 41, 74, 99], "final": [0, 99], "millisecond": [0, 99], "allow": [0, 41, 99, 100], "whole": [0, 99], "new": [0, 1, 35, 98, 99], "type": [0, 5, 14, 15, 27, 28, 29, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95, 98, 99], "cosmic": [0, 99], "alert": [0, 99], "which": [0, 15, 16, 25, 29, 40, 75, 84, 99, 100], "were": [0, 99], "previous": [0, 99], "unfeas": [0, 99], "possibl": [0, 28, 98, 99], "identifi": [0, 5, 25, 99], "10": [0, 84, 99], "tev": [0, 99], "monitor": [0, 99], "rate": [0, 99], "direct": [0, 99], "real": [0, 99], "thi": [0, 3, 5, 15, 17, 30, 32, 35, 36, 39, 75, 76, 82, 95, 98, 99, 100], "enabl": [0, 3, 99], "first": [0, 98, 99], "ever": [0, 99], "despit": [0, 99], "larg": [0, 99], "background": [0, 99], "origin": [0, 75, 99], "compris": [0, 99], "number": [0, 5, 32, 35, 40, 84, 99], "modul": [0, 3, 30, 41, 74, 77, 83, 94, 99], "necessari": [0, 28, 98, 99], "workflow": [0, 99], "ingest": [0, 1, 3, 99], "raw": [0, 99], "data": [0, 1, 4, 5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 84, 94, 97, 99, 100], "domain": [0, 1, 3, 41, 99], "specif": [0, 1, 3, 5, 16, 30, 31, 32, 34, 35, 36, 41, 98, 99, 100], "format": [0, 1, 3, 5, 28, 32, 35, 76, 98, 99, 100], "deploi": [0, 1, 41, 99], "chain": [0, 1, 41, 99, 100], "illustr": [0, 98, 99], "figur": [0, 76, 99], "level": [0, 25, 36, 95, 99, 100], "overview": [0, 99], "typic": [0, 28, 99], "convert": [0, 1, 3, 5, 28, 32, 35, 38, 99, 100], "industri": [0, 3, 99], "standard": [0, 3, 4, 5, 32, 35, 40, 84, 98, 99], "intermedi": [0, 1, 3, 5, 32, 35, 99, 100], "file": [0, 1, 3, 5, 15, 28, 32, 35, 38, 39, 75, 84, 93, 95, 99, 100], "read": [0, 3, 28, 99, 100], "simpl": [0, 99], "physic": [0, 1, 15, 29, 30, 41, 99], "orient": [0, 99], "compon": [0, 1, 45, 99], "manag": [0, 15, 77, 99], "experi": [0, 1, 77, 99], "log": [0, 1, 77, 83, 99, 100], "deploy": [0, 1, 42, 97, 99], "modular": [0, 99], "subclass": [0, 99], "torch": [0, 94, 99, 100], "nn": [0, 99], "mean": [0, 5, 32, 35, 99], "onli": [0, 1, 75, 82, 94, 99, 100], "need": [0, 28, 99, 100], "import": [0, 1, 36, 83, 99], "few": [0, 98, 99], "exist": [0, 35, 36, 99], "purpos": [0, 99], "built": [0, 99], "them": [0, 1, 28, 75, 99, 100], "togeth": [0, 99], "form": [0, 99], "complet": [0, 99], "extend": [0, 1, 99], "suit": [0, 99], "through": [0, 99], "layer": [0, 45, 47, 99], "connect": [0, 99], "etc": [0, 95, 99], "optimis": [0, 1, 99], "differ": [0, 15, 98, 99, 100], "track": [0, 15, 19, 98, 99], "These": [0, 98, 99], "prepar": [0, 99], "satisfi": [0, 99], "o": [0, 99], "load": [0, 39, 99], "requir": [0, 21, 36, 99, 100], "when": [0, 5, 28, 32, 35, 36, 95, 98, 99, 100], "batch": [0, 84, 99], "do": [0, 98, 99, 100], "predict": [0, 20, 24, 26, 99], "either": [0, 99, 100], "contain": [0, 5, 28, 29, 32, 35, 82, 84, 99, 100], "imag": [0, 1, 98, 99, 100], "portabl": [0, 99], "depend": [0, 99, 100], "free": [0, 99], "split": [0, 99], "up": [0, 5, 32, 35, 98, 99, 100], "interfac": [0, 74, 99, 100], "block": [0, 1, 99], "pre": [0, 98, 99], "directli": [0, 15, 99], "while": [0, 17, 99], "continu": [0, 99], "expand": [0, 99], "": [0, 5, 15, 28, 35, 38, 82, 84, 95, 99, 100], "capabl": [0, 99], "project": [0, 98, 99], "receiv": [0, 99], "fund": [0, 99], "european": [0, 99], "union": [0, 17, 28, 30, 93, 99], "horizon": [0, 99], "2020": [0, 99], "innov": [0, 99], "programm": [0, 99], "under": [0, 99], "mari": [0, 99], "sk\u0142odowska": [0, 99], "curi": [0, 99], "grant": [0, 99], "agreement": [0, 98, 99], "No": [0, 99], "890778": [0, 99], "work": [0, 4, 29, 98, 99, 100], "rasmu": [0, 99], "\u00f8rs\u00f8e": [0, 99], "partli": [0, 99], "punch4nfdi": [0, 99], "consortium": [0, 99], "support": [0, 30, 98, 99, 100], "dfg": [0, 99], "nfdi": [0, 99], "39": [0, 99, 100], "1": [0, 5, 28, 32, 35, 40, 82, 99, 100], "germani": [0, 99], "conveni": [1, 98, 100], "collabor": 1, "solv": [1, 98], "It": [1, 28, 36, 98], "leverag": 1, "advanc": 1, "machin": [1, 100], "learn": [1, 100], "without": [1, 75, 100], "have": [1, 5, 17, 32, 35, 36, 40, 98, 100], "expert": 1, "themselv": 1, "acceler": 1, "area": 1, "phyic": 1, "design": 1, "principl": 1, "all": [1, 5, 15, 17, 32, 35, 36, 95, 98, 100], "streamlin": 1, "process": [1, 5, 15, 98, 100], "transform": [1, 82], "extens": [1, 93], "basic": 1, "across": [1, 2, 30, 37, 83, 84, 95], "variou": 1, "easili": 1, "architectur": 1, "main": [1, 98, 100], "featur": [1, 3, 4, 5, 16, 98], "i3": [1, 5, 15, 29, 30, 32, 35, 39, 93, 100], "more": [1, 36, 39, 95], "index": [1, 5, 30, 36], "sqlite": [1, 3, 7, 35, 36, 38, 100], "suitabl": 1, "plug": 1, "plai": 1, "abstract": [1, 5], "awai": 1, "detail": [1, 100], "expos": 1, "physicst": 1, "what": [1, 98], "i3modul": [1, 41], "includ": [1, 75, 98], "docker": 1, "run": [1, 38], "containeris": 1, "fashion": 1, "subpackag": [1, 3, 7, 14, 41, 45, 60, 83], "dataset": [1, 3, 19, 40, 84], "extractor": [1, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35], "parquet": [1, 3, 7, 32, 38, 100], "util": [1, 3, 14, 28, 29, 30, 36, 38, 39, 40, 45, 77, 84, 93, 94, 95, 97], "constant": [1, 3, 97], "dataconvert": [1, 3, 32, 35], "dataload": [1, 3], "pipelin": [1, 3], "coarsen": [1, 45], "standard_model": [1, 45], "pisa": [1, 21, 75, 76, 94, 97, 100], "fit": [1, 74, 76, 82], "plot": [1, 74], "callback": [1, 77], "label": [1, 19, 22, 76, 77], "loss_funct": [1, 77], "weight_fit": [1, 77], "config": [1, 40, 75, 83, 84], "argpars": [1, 83], "decor": [1, 5, 83, 94], "filesi": [1, 83], "math": [1, 83], "submodul": [1, 3, 7, 9, 11, 14, 27, 31, 34, 37, 42, 45, 47, 50, 54, 60, 61, 65, 69, 74, 77, 83, 85], "global": [2, 4], "i3extractor": [3, 5, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35], "i3featureextractor": [3, 4, 14, 35], "i3genericextractor": [3, 14, 35], "i3hybridrecoextractor": [3, 14], "i3ntmuonlabelsextractor": [3, 14], "i3particleextractor": [3, 14], "i3pisaextractor": [3, 14], "i3quesoextractor": [3, 14], "i3retroextractor": [3, 14], "i3splinempeextractor": [3, 14], "i3truthextractor": [3, 4, 14], "i3tumextractor": [3, 14], "parquet_dataconvert": [3, 31], "sqlite_dataconvert": [3, 34], "sqlite_util": [3, 34], "parquet_to_sqlit": [3, 37], "random": [3, 37, 40], "string_selection_resolv": [3, 37], "truth": [3, 4, 16, 25, 36, 82], "fileset": [3, 5], "init_global_index": [3, 5], "cache_output_fil": [3, 5], "class": [4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 34, 35, 38, 40, 75, 82, 84, 95, 98], "object": [4, 5, 15, 17, 28, 30, 75, 84, 95], "namespac": 4, "name": [4, 5, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 30, 32, 35, 36, 38, 75, 82, 84, 95, 98, 100], "icecube86": 4, "dom_x": 4, "dom_i": 4, "dom_z": 4, "dom_tim": 4, "charg": 4, "rde": 4, "pmt_area": 4, "deepcor": [4, 16], "upgrad": [4, 16, 100], "string": [4, 5, 28, 32, 35, 40], "pmt_number": 4, "dom_numb": 4, "pmt_dir_x": 4, "pmt_dir_i": 4, "pmt_dir_z": 4, "dom_typ": 4, "prometheu": [4, 45, 50], "sensor_pos_x": 4, "sensor_pos_i": 4, "sensor_pos_z": 4, "t": [4, 30, 36, 76, 100], "kaggl": 4, "x": [4, 5, 25, 32, 35, 76, 82], "y": [4, 25, 76, 100], "z": [4, 5, 25, 32, 35, 100], "auxiliari": 4, "energy_track": 4, "position_x": 4, "position_i": 4, "position_z": 4, "azimuth": 4, "zenith": 4, "pid": [4, 40], "elast": 4, "sim_typ": 4, "interaction_typ": 4, "interaction_tim": 4, "inelast": 4, "stopped_muon": 4, "injection_energi": 4, "injection_typ": 4, "injection_interaction_typ": 4, "injection_zenith": 4, "injection_azimuth": 4, "injection_bjorkenx": 4, "injection_bjorkeni": 4, "injection_position_x": 4, "injection_position_i": 4, "injection_position_z": 4, "injection_column_depth": 4, "primary_lepton_1_typ": 4, "primary_hadron_1_typ": 4, "primary_lepton_1_position_x": 4, "primary_lepton_1_position_i": 4, "primary_lepton_1_position_z": 4, "primary_hadron_1_position_x": 4, "primary_hadron_1_position_i": 4, "primary_hadron_1_position_z": 4, "primary_lepton_1_direction_theta": 4, "primary_lepton_1_direction_phi": 4, "primary_hadron_1_direction_theta": 4, "primary_hadron_1_direction_phi": 4, "primary_lepton_1_energi": 4, "primary_hadron_1_energi": 4, "total_energi": 4, "i3_fil": [5, 15], "str": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 82, 84, 93, 95], "gcd_file": [5, 15], "paramet": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95], "output_fil": [5, 32, 35], "global_index": 5, "avail": [5, 17, 94], "pool": [5, 45, 47], "worker": [5, 32, 35, 39, 84, 95], "return": [5, 15, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95], "none": [5, 15, 17, 25, 29, 30, 32, 35, 36, 38, 40, 75, 82, 84, 93, 95], "synchron": 5, "list": [5, 15, 17, 25, 28, 30, 32, 35, 36, 38, 39, 40, 76, 82, 93, 95], "process_method": 5, "cach": 5, "output": [5, 32, 35, 38, 75, 82, 100], "typevar": 5, "f": 5, "bound": [5, 76], "callabl": [5, 30, 82, 94], "ani": [5, 28, 29, 30, 32, 35, 76, 82, 84, 95, 100], "outdir": [5, 32, 35, 38, 75], "gcd_rescu": [5, 32, 35, 93], "nb_files_to_batch": [5, 32, 35], "sequential_batch_pattern": [5, 32, 35], "input_file_batch_pattern": [5, 32, 35], "index_column": [5, 32, 35, 36, 40, 75, 82], "icetray_verbos": [5, 32, 35], "abc": [5, 15, 82], "logger": [5, 15, 38, 40, 82, 83, 95, 100], "construct": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 75, 82, 84, 95], "regular": [5, 30, 32, 35], "express": [5, 32, 35], "accord": [5, 32, 35], "match": [5, 32, 35, 82, 93], "certain": [5, 32, 35, 38, 75], "pattern": [5, 32, 35], "wildcard": [5, 32, 35], "same": [5, 30, 32, 35, 36, 95], "input": [5, 32, 35], "replac": [5, 32, 35], "period": [5, 32, 35], "special": [5, 17, 32, 35], "interpret": [5, 32, 35], "liter": [5, 32, 35], "charact": [5, 32, 35], "regex": [5, 32, 35], "For": [5, 30, 32, 35], "instanc": [5, 15, 25, 30, 32, 35, 75, 100], "A": [5, 32, 35, 75, 82, 100], "_": [5, 32, 35], "0": [5, 32, 35, 40, 75, 76], "9": [5, 32, 35], "5": [5, 32, 35, 40, 84, 100], "zst": [5, 32, 35], "find": [5, 32, 35, 93], "whose": [5, 32, 35], "one": [5, 32, 35, 36, 93, 98, 100], "capit": [5, 32, 35], "letter": [5, 32, 35], "follow": [5, 32, 35, 82, 98, 100], "underscor": [5, 32, 35], "five": [5, 32, 35], "upgrade_genie_step4_141020_a_000000": [5, 32, 35], "upgrade_genie_step4_141020_a_000001": [5, 32, 35], "upgrade_genie_step4_141020_a_000008": [5, 32, 35], "upgrade_genie_step4_141020_a_000009": [5, 32, 35], "would": [5, 32, 35, 98], "upgrade_genie_step4_141020_a_00000x": [5, 32, 35], "suffix": [5, 32, 35], "upgrade_genie_step4_141020_a_000010": [5, 32, 35], "separ": [5, 28, 32, 35, 100], "upgrade_genie_step4_141020_a_00001x": [5, 32, 35], "int": [5, 19, 22, 32, 35, 40, 75, 82, 84, 95], "properti": [5, 15, 20, 30, 95], "file_suffix": [5, 32, 35], "execut": [5, 36], "method": [5, 15, 27, 28, 29, 30, 32, 35, 82], "set": [5, 17, 98], "inherit": [5, 15, 30, 95], "path": [5, 36, 39, 75, 76, 84, 93, 100], "correspond": [5, 28, 30, 35, 39, 82, 93, 100], "gcd": [5, 15, 29, 39, 93], "save_data": [5, 32, 35], "save": [5, 15, 28, 32, 35, 36, 75, 82, 100], "ordereddict": [5, 32, 35], "extract": [5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 35, 38, 39], "merge_fil": [5, 32, 35], "input_fil": [5, 32, 35], "merg": [5, 32, 35, 100], "result": [5, 32, 35, 100], "option": [5, 25, 32, 35, 75, 76, 82, 83, 84, 93, 100], "default": [5, 17, 25, 28, 32, 35, 36, 38, 75, 76, 82, 84, 93], "current": [5, 32, 35, 40, 98, 100], "rais": [5, 17, 32], "notimplementederror": [5, 32], "If": [5, 17, 32, 35, 75, 82, 98, 100], "been": [5, 32, 98], "backend": [5, 32, 35], "question": 5, "get_map_funct": 5, "nb_file": 5, "map": [5, 16, 17, 35, 36], "pure": [5, 14, 15, 17, 30], "multiprocess": [5, 100], "tupl": [5, 29, 30, 75, 76, 84], "parquet_dataset": [7, 9], "sqlite_dataset": [7, 11], "sqlite_dataset_perturb": [7, 11], "collect": [14, 15, 27], "i3fram": [14, 15, 17, 29, 30], "frame": [14, 15, 17, 27, 30, 35], "i3extractorcollect": [14, 15], "i3featureextractoricecube86": [14, 16], "i3featureextractoricecubedeepcor": [14, 16], "i3featureextractoricecubeupgrad": [14, 16], "i3pulsenoisetruthflagicecubeupgrad": [14, 16], "i3galacticplanehybridrecoextractor": [14, 18], "i3ntmuonlabelextractor": [14, 19], "i3splinempeicextractor": [14, 24], "inform": [15, 17, 25, 76], "should": [15, 28, 40, 98, 100], "__call__": 15, "icetrai": [15, 29, 30, 94], "keep": 15, "proven": 15, "tabl": [15, 35, 36, 75, 82], "set_fil": 15, "store": [15, 36, 75], "refer": 15, "being": 15, "get": [15, 29, 100], "multipl": [15, 95], "treat": 15, "singl": 15, "pulsemap": [16, 35], "puls": [16, 17, 29, 30, 35, 36], "seri": [16, 17, 29, 30, 36], "86": 16, "nois": [16, 29], "flag": 16, "ad": [16, 75], "kei": [17, 28, 29, 30, 35, 36], "exclude_kei": 17, "dynam": 17, "pars": [17, 76, 83, 84, 85], "call": [17, 30, 35, 75, 82, 95], "tri": [17, 30], "automat": [17, 98], "cast": [17, 30], "done": [17, 95, 98], "recurs": [17, 30, 93], "each": [17, 28, 30, 36, 38, 39, 75, 76, 93], "look": [17, 100], "member": [17, 30, 95], "variabl": [17, 30, 82, 95], "signatur": [17, 30], "similar": [17, 30, 100], "dict": [17, 28, 30, 35, 75, 76, 84], "handl": [17, 84, 95], "hand": 17, "case": [17, 100], "per": [17, 36, 82], "mc": [17, 35, 36], "tree": [17, 35], "trigger": 17, "exclud": [17, 38, 100], "valueerror": 17, "hybrid": 18, "galatict": 18, "plane": 18, "tum": [19, 26], "dnn": [19, 26], "padding_valu": [19, 22], "northeren": 19, "i3particl": 20, "other": [20, 36, 98], "algorithm": 20, "comparison": 20, "quantiti": 21, "select": [22, 40, 82, 98], "queso": 22, "retro": 23, "splinemp": 24, "border": 25, "mctree": [25, 29], "ndarrai": [25, 82], "arrai": [25, 28], "boundari": 25, "volum": 25, "coordin": 25, "particl": [25, 36], "start": [25, 98, 100], "stop": [25, 84], "within": 25, "hard": 25, "i3mctre": 25, "valu": [25, 28, 35, 36, 76, 84], "flatten_nested_dictionari": [27, 28], "serialis": [27, 28], "transpose_list_of_dict": [27, 28], "frame_is_montecarlo": [27, 29], "frame_is_nois": [27, 29], "get_om_keys_and_pulseseri": [27, 29], "is_boost_enum": [27, 30], "is_boost_class": [27, 30], "is_icecube_class": [27, 30], "is_typ": [27, 30], "is_method": [27, 30], "break_cyclic_recurs": [27, 30], "get_member_vari": [27, 30], "cast_object_to_pure_python": [27, 30], "cast_pulse_series_to_pure_python": [27, 30], "manipul": 28, "obj": [28, 30], "parent_kei": 28, "flatten": 28, "nest": 28, "dictionari": [28, 29, 30, 35, 75, 76], "non": [28, 30, 35, 36], "exampl": [28, 40, 100], "d": [28, 98], "b": 28, "c": [28, 100], "2": [28, 75, 76, 100], "a__b": 28, "applic": 28, "combin": 28, "parent": 28, "__": [28, 30], "concaten": 28, "nester": 28, "json": 28, "therefor": 28, "we": [28, 30, 40, 98, 100], "element": [28, 30], "outer": 28, "abl": [28, 100], "de": 28, "transpos": 28, "check": [29, 30, 35, 36, 84, 93, 94, 98, 100], "whether": [29, 30, 35, 36, 93, 94], "mont": 29, "carlo": 29, "simul": 29, "bool": [29, 30, 35, 36, 40, 75, 82, 84, 93, 94, 95], "pulseseri": 29, "calibr": [29, 30], "indici": [29, 40], "gcd_dict": [29, 30], "p": [29, 35], "om": [29, 30], "dataclass": 29, "i3calibr": 29, "indicesfor": 29, "boost": 30, "enum": 30, "fn": 30, "ensur": [30, 39, 95, 98, 100], "isn": 30, "return_discard": 30, "valid": [30, 40, 84], "ignor": 30, "mangl": 30, "take": [30, 35, 98], "mainli": 30, "cannot": 30, "trivial": 30, "doe": 30, "try": 30, "length": 30, "equival": 30, "its": 30, "like": [30, 98], "otherwis": 30, "itself": 30, "deem": 30, "wai": [30, 40, 98, 100], "represent": 30, "optic": 30, "found": 30, "parquetdataconvert": [31, 32], "sqlitedataconvert": [34, 35, 100], "construct_datafram": [34, 35], "is_pulse_map": [34, 35], "is_mc_tre": [34, 35], "database_exist": [34, 36], "database_table_exist": [34, 36], "run_sql_cod": [34, 36], "save_to_sql": [34, 36], "attach_index": [34, 36], "create_t": [34, 36], "create_table_and_save_to_sql": [34, 36], "db": 35, "databas": [35, 36, 38, 75, 82, 100], "max_table_s": 35, "maximum": [35, 84], "row": [35, 36], "given": [35, 82, 84], "exce": 35, "limit": 35, "creat": [35, 36, 98, 100], "any_pulsemap_is_non_empti": 35, "data_dict": 35, "empti": 35, "retriev": 35, "splitinicepuls": 35, "least": [35, 98, 100], "true": [35, 36, 75, 82], "becaus": [35, 39], "instead": 35, "alwai": 35, "panda": [35, 40, 82], "datafram": [35, 36, 40, 75, 82], "table_nam": [35, 36], "database_path": [36, 75, 82], "df": 36, "must": [36, 82, 98], "alreadi": [36, 100], "attach": 36, "queri": [36, 40], "column": [36, 75, 82], "default_typ": 36, "null": 36, "integer_primary_kei": 36, "event_no": [36, 40, 82], "NOT": 36, "integ": 36, "primari": 36, "Such": 36, "uniqu": [36, 38], "appropri": 36, "expect": [36, 40], "doesn": 36, "parquettosqliteconvert": [37, 38], "pairwise_shuffl": [37, 39], "stringselectionresolv": [37, 40], "parquet_path": 38, "mc_truth_tabl": 38, "excluded_field": 38, "assign": [38, 98], "id": 38, "everi": [38, 100], "field": [38, 76], "One": [38, 76], "choos": 38, "argument": [38, 82, 84], "exclude_field": 38, "database_nam": 38, "convers": [38, 100], "directori": [38, 75, 93], "rng": 39, "relat": [39, 93], "i3_list": [39, 93], "gcd_list": [39, 93], "shuffl": 39, "correpond": 39, "handi": 39, "even": 39, "files_list": 39, "gcd_shuffl": 39, "i3_shuffl": 39, "resolv": 40, "indic": [40, 84, 98], "seed": 40, "use_cach": 40, "datasetconfig": 40, "flexibl": 40, "defin": 40, "below": [40, 76, 82, 98, 100], "show": 40, "involv": 40, "cover": 40, "yml": [40, 84], "test": [40, 94, 98], "50000": 40, "ab": 40, "12": 40, "14": 40, "16": 40, "13": [40, 100], "10000": 40, "compat": 40, "syntax": 40, "mai": [40, 100], "also": 40, "specifi": [40, 76, 100], "fix": 40, "randomli": 40, "20": [40, 95], "graphnet_modul": [41, 42], "convnet": [45, 54], "dynedg": [45, 54], "dynedge_jinst": [45, 54], "dynedge_kaggle_tito": [45, 54], "edg": [45, 60], "node": [45, 60], "graph_definit": [45, 60], "config_updat": [74, 75], "weightfitt": [74, 75, 77, 82], "contourfitt": [74, 75], "read_entri": [74, 76], "plot_2d_contour": [74, 76], "plot_1d_contour": [74, 76], "contour": [75, 76], "config_path": 75, "new_config_path": 75, "dummy_sect": 75, "updat": 75, "temp": 75, "dummi": 75, "section": 75, "header": 75, "configupdat": 75, "programat": 75, "truth_tabl": [75, 82], "statistical_fit": 75, "weight": [75, 82, 100], "fit_weight": [75, 82], "config_outdir": 75, "weight_nam": [75, 82], "pisa_config_dict": 75, "add_to_databas": [75, 82], "flux": 75, "self": 75, "_database_path": 75, "statist": 75, "effect": [75, 98], "account": 75, "systemat": 75, "hypersurfac": 75, "chang": [75, 98], "assumpt": 75, "regard": 75, "fals": [75, 82], "two": 75, "pipeline_path": 75, "post_fix": 75, "model_nam": 75, "include_retro": 75, "fit_1d_contour": 75, "run_nam": 75, "config_dict": 75, "grid_siz": 75, "n_worker": 75, "theta23_minmax": 75, "36": 75, "54": 75, "dm31_minmax": 75, "3": [75, 76, 98, 100], "7": 75, "1d": [75, 76], "float": [75, 76], "fit_2d_contour": 75, "2d": [75, 76], "entri": [76, 84], "content": 76, "contour_data": 76, "xlim": 76, "4": 76, "6": 76, "ylim": 76, "0023799999999999997": 76, "0025499999999999997": 76, "chi2_critical_valu": 76, "width": 76, "height": 76, "path_to_pisa_fit_result": 76, "name_of_my_model_in_fit": 76, "legend": 76, "color": 76, "linestyl": 76, "style": [76, 98], "line": [76, 84], "upper": 76, "axi": 76, "605": 76, "critic": [76, 95], "chi2": 76, "90": 76, "cl": 76, "note": 76, "right": 76, "176": 76, "inch": 76, "388": 76, "706": 76, "abov": [76, 82, 100], "352": 76, "uniform": [77, 82], "bjoernlow": [77, 82], "produc": 82, "public": 82, "uniformweightfitt": 82, "bin": 82, "kwarg": [82, 95], "privat": 82, "_fit_weight": 82, "sql": 82, "desir": [82, 93], "space": 82, "np": 82, "log10": 82, "happen": 82, "addit": 82, "pass": [82, 98], "distribut": 82, "x_low": 82, "wherea": 82, "curv": 82, "base_config": [83, 85], "dataset_config": [83, 85], "model_config": [83, 85], "training_config": [83, 85], "argumentpars": [83, 84], "is_gcd_fil": [83, 93], "is_i3_fil": [83, 93], "has_extens": [83, 93], "find_i3_fil": [83, 93], "has_icecube_packag": [83, 94], "has_torch_packag": [83, 94], "has_pisa_packag": [83, 94], "requires_icecub": [83, 94], "repeatfilt": [83, 95], "consist": [84, 95, 98], "cli": 84, "present": [84, 93, 94], "pop_default": 84, "remov": 84, "usag": 84, "descript": 84, "command": [84, 100], "standard_argu": 84, "size": 84, "128": 84, "help": [84, 98], "home": [84, 100], "runner": 84, "local": 84, "lib": [84, 100], "python3": 84, "training_example_data_sqlit": 84, "earli": 84, "patienc": 84, "epoch": 84, "loss": 84, "after": 84, "gpu": [84, 100], "narg": 84, "max": 84, "50": 84, "example_energy_reconstruction_model": 84, "num": 84, "fetch": 84, "with_standard_argu": 84, "arg": [84, 95], "add": [84, 98, 100], "overwritten": 84, "system": [93, 100], "filenam": 93, "dir": 93, "search": 93, "test_funct": 94, "filter": 95, "out": [95, 98, 100], "repeat": 95, "messag": 95, "nb_repeats_allow": 95, "record": 95, "print": 95, "logrecord": 95, "class_nam": 95, "log_fold": 95, "clear": 95, "intuit": 95, "composit": 95, "rather": 95, "loggeradapt": 95, "chosen": 95, "avoid": [95, 98], "clash": 95, "pytorch_lightn": 95, "lightningmodul": 95, "setlevel": 95, "deleg": 95, "msg": 95, "error": [95, 98], "warn": 95, "info": [95, 100], "debug": 95, "warning_onc": 95, "exactli": 95, "onc": 95, "handler": 95, "file_handl": 95, "filehandl": 95, "stream_handl": 95, "streamhandl": 95, "api": 97, "To": [98, 100], "sure": [98, 100], "smooth": 98, "guidelin": 98, "guid": 98, "encourag": 98, "contributor": 98, "discuss": 98, "bug": 98, "anyth": 98, "you": [98, 100], "place": 98, "describ": 98, "altern": 98, "yourself": 98, "ownership": 98, "particular": 98, "activ": [98, 100], "transpar": 98, "prioriti": 98, "situat": 98, "lot": 98, "effort": 98, "go": 98, "turn": 98, "outsid": 98, "scope": 98, "solut": 98, "better": 98, "fork": 98, "repo": 98, "dedic": 98, "branch": [98, 100], "your": [98, 100], "repositori": 98, "graphdefinit": 98, "euclidean": 98, "definit": 98, "own": [98, 100], "team": 98, "accept": 98, "autom": 98, "review": 98, "pep8": 98, "docstr": 98, "googl": 98, "hint": 98, "clean": [98, 100], "see": [98, 100], "version": [98, 100], "8": [98, 100], "adher": 98, "pep": 98, "pylint": 98, "flake8": 98, "black": 98, "well": 98, "recommend": [98, 100], "mypi": 98, "pydocstyl": 98, "docformatt": 98, "commit": 98, "hook": 98, "instal": 98, "come": 98, "tag": [98, 100], "pip": [98, 100], "Then": 98, "everytim": 98, "pep257": 98, "static": 98, "concept": 98, "http": 98, "ljvmiranda921": 98, "io": 98, "notebook": 98, "2018": 98, "06": 98, "21": 98, "precommit": 98, "environ": 100, "virtual": 100, "anaconda": 100, "prove": 100, "instruct": 100, "setup": 100, "want": 100, "part": 100, "In": 100, "runtim": 100, "achiev": 100, "bash": 100, "shell": 100, "eval": 100, "cvmf": 100, "opensciencegrid": 100, "org": 100, "py3": 100, "v4": 100, "sh": 100, "rhel_7_x86_64": 100, "metaproject": 100, "v1": 100, "env": 100, "alia": 100, "script": 100, "With": 100, "now": 100, "light": 100, "extra": 100, "pytorch": 100, "geometr": 100, "just": 100, "won": 100, "later": 100, "don": 100, "r": 100, "torch_cpu": 100, "txt": 100, "cpu": 100, "torch_gpu": 100, "prefer": 100, "unix": 100, "git": 100, "clone": 100, "github": 100, "com": 100, "usernam": 100, "cd": 100, "conda": 100, "gcc_linux": 100, "64": 100, "gxx_linux": 100, "libgcc": 100, "cudatoolkit": 100, "11": 100, "forg": 100, "torch_maco": 100, "On": 100, "maco": 100, "box": 100, "compil": 100, "gcc": 100, "date": 100, "possibli": 100, "cuda": 100, "toolkit": 100, "recent": 100, "omit": 100, "newer": 100, "export": 100, "ld_library_path": 100, "anaconda3": 100, "miniconda3": 100, "bashrc": 100, "librari": 100, "access": 100, "so": 100, "re": 100, "intend": 100, "consid": 100, "rm": 100, "asogaard": 100, "latest": 100, "dc423315742c": 100, "01_icetrai": 100, "01_convert_i3_fil": 100, "py": 100, "2023": 100, "01": 100, "24": 100, "41": 100, "27": 100, "__init__": 100, "write": 100, "graphnet_20230124": 100, "134127": 100, "46": 100, "root": 100, "convert_i3_fil": 100, "ic86": 100, "thread": 100, "100": 100, "00": 100, "79": 100, "42": 100, "26": 100, "413": 100, "88it": 100, "specialis": 100, "ones": 100, "push": 100, "vx": 100}, "objects": {"": [[1, 0, 0, "-", "graphnet"]], "graphnet": [[2, 0, 0, "-", "constants"], [3, 0, 0, "-", "data"], [41, 0, 0, "-", "deployment"], [74, 0, 0, "-", "pisa"], [77, 0, 0, "-", "training"], [83, 0, 0, "-", "utilities"]], "graphnet.data": [[4, 0, 0, "-", "constants"], [5, 0, 0, "-", "dataconverter"], [14, 0, 0, "-", "extractors"], [31, 0, 0, "-", "parquet"], [34, 0, 0, "-", "sqlite"], [37, 0, 0, "-", "utilities"]], "graphnet.data.constants": [[4, 1, 1, "", "FEATURES"], [4, 1, 1, "", "TRUTH"]], "graphnet.data.constants.FEATURES": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.constants.TRUTH": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.dataconverter": [[5, 1, 1, "", "DataConverter"], [5, 1, 1, "", "FileSet"], [5, 5, 1, "", "cache_output_files"], [5, 5, 1, "", "init_global_index"]], "graphnet.data.dataconverter.DataConverter": [[5, 3, 1, "", "execute"], [5, 4, 1, "", "file_suffix"], [5, 3, 1, "", "get_map_function"], [5, 3, 1, "", "merge_files"], [5, 3, 1, "", "save_data"]], "graphnet.data.dataconverter.FileSet": [[5, 2, 1, "", "gcd_file"], [5, 2, 1, "", "i3_file"]], "graphnet.data.extractors": [[15, 0, 0, "-", "i3extractor"], [16, 0, 0, "-", "i3featureextractor"], [17, 0, 0, "-", "i3genericextractor"], [18, 0, 0, "-", "i3hybridrecoextractor"], [19, 0, 0, "-", "i3ntmuonlabelsextractor"], [20, 0, 0, "-", "i3particleextractor"], [21, 0, 0, "-", "i3pisaextractor"], [22, 0, 0, "-", "i3quesoextractor"], [23, 0, 0, "-", "i3retroextractor"], [24, 0, 0, "-", "i3splinempeextractor"], [25, 0, 0, "-", "i3truthextractor"], [26, 0, 0, "-", "i3tumextractor"], [27, 0, 0, "-", "utilities"]], "graphnet.data.extractors.i3extractor": [[15, 1, 1, "", "I3Extractor"], [15, 1, 1, "", "I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor.I3Extractor": [[15, 4, 1, "", "name"], [15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3extractor.I3ExtractorCollection": [[15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3featureextractor": [[16, 1, 1, "", "I3FeatureExtractor"], [16, 1, 1, "", "I3FeatureExtractorIceCube86"], [16, 1, 1, "", "I3FeatureExtractorIceCubeDeepCore"], [16, 1, 1, "", "I3FeatureExtractorIceCubeUpgrade"], [16, 1, 1, "", "I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3genericextractor": [[17, 1, 1, "", "I3GenericExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, 1, 1, "", "I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, 1, 1, "", "I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, 1, 1, "", "I3ParticleExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, 1, 1, "", "I3PISAExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, 1, 1, "", "I3QUESOExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, 1, 1, "", "I3RetroExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, 1, 1, "", "I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, 1, 1, "", "I3TruthExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, 1, 1, "", "I3TUMExtractor"]], "graphnet.data.extractors.utilities": [[28, 0, 0, "-", "collections"], [29, 0, 0, "-", "frames"], [30, 0, 0, "-", "types"]], "graphnet.data.extractors.utilities.collections": [[28, 5, 1, "", "flatten_nested_dictionary"], [28, 5, 1, "", "serialise"], [28, 5, 1, "", "transpose_list_of_dicts"]], "graphnet.data.extractors.utilities.frames": [[29, 5, 1, "", "frame_is_montecarlo"], [29, 5, 1, "", "frame_is_noise"], [29, 5, 1, "", "get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.types": [[30, 5, 1, "", "break_cyclic_recursion"], [30, 5, 1, "", "cast_object_to_pure_python"], [30, 5, 1, "", "cast_pulse_series_to_pure_python"], [30, 5, 1, "", "get_member_variables"], [30, 5, 1, "", "is_boost_class"], [30, 5, 1, "", "is_boost_enum"], [30, 5, 1, "", "is_icecube_class"], [30, 5, 1, "", "is_method"], [30, 5, 1, "", "is_type"]], "graphnet.data.parquet": [[32, 0, 0, "-", "parquet_dataconverter"]], "graphnet.data.parquet.parquet_dataconverter": [[32, 1, 1, "", "ParquetDataConverter"]], "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter": [[32, 2, 1, "", "file_suffix"], [32, 3, 1, "", "merge_files"], [32, 3, 1, "", "save_data"]], "graphnet.data.sqlite": [[35, 0, 0, "-", "sqlite_dataconverter"], [36, 0, 0, "-", "sqlite_utilities"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, 1, 1, "", "SQLiteDataConverter"], [35, 5, 1, "", "construct_dataframe"], [35, 5, 1, "", "is_mc_tree"], [35, 5, 1, "", "is_pulse_map"]], "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter": [[35, 3, 1, "", "any_pulsemap_is_non_empty"], [35, 2, 1, "", "file_suffix"], [35, 3, 1, "", "merge_files"], [35, 3, 1, "", "save_data"]], "graphnet.data.sqlite.sqlite_utilities": [[36, 5, 1, "", "attach_index"], [36, 5, 1, "", "create_table"], [36, 5, 1, "", "create_table_and_save_to_sql"], [36, 5, 1, "", "database_exists"], [36, 5, 1, "", "database_table_exists"], [36, 5, 1, "", "run_sql_code"], [36, 5, 1, "", "save_to_sql"]], "graphnet.data.utilities": [[38, 0, 0, "-", "parquet_to_sqlite"], [39, 0, 0, "-", "random"], [40, 0, 0, "-", "string_selection_resolver"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, 1, 1, "", "ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter": [[38, 3, 1, "", "run"]], "graphnet.data.utilities.random": [[39, 5, 1, "", "pairwise_shuffle"]], "graphnet.data.utilities.string_selection_resolver": [[40, 1, 1, "", "StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver": [[40, 3, 1, "", "resolve"]], "graphnet.pisa": [[75, 0, 0, "-", "fitting"], [76, 0, 0, "-", "plotting"]], "graphnet.pisa.fitting": [[75, 1, 1, "", "ContourFitter"], [75, 1, 1, "", "WeightFitter"], [75, 5, 1, "", "config_updater"]], "graphnet.pisa.fitting.ContourFitter": [[75, 3, 1, "", "fit_1d_contour"], [75, 3, 1, "", "fit_2d_contour"]], "graphnet.pisa.fitting.WeightFitter": [[75, 3, 1, "", "fit_weights"]], "graphnet.pisa.plotting": [[76, 5, 1, "", "plot_1D_contour"], [76, 5, 1, "", "plot_2D_contour"], [76, 5, 1, "", "read_entry"]], "graphnet.training": [[82, 0, 0, "-", "weight_fitting"]], "graphnet.training.weight_fitting": [[82, 1, 1, "", "BjoernLow"], [82, 1, 1, "", "Uniform"], [82, 1, 1, "", "WeightFitter"]], "graphnet.training.weight_fitting.WeightFitter": [[82, 3, 1, "", "fit"]], "graphnet.utilities": [[84, 0, 0, "-", "argparse"], [92, 0, 0, "-", "decorators"], [93, 0, 0, "-", "filesys"], [94, 0, 0, "-", "imports"], [95, 0, 0, "-", "logging"]], "graphnet.utilities.argparse": [[84, 1, 1, "", "ArgumentParser"], [84, 1, 1, "", "Options"]], "graphnet.utilities.argparse.ArgumentParser": [[84, 2, 1, "", "standard_arguments"], [84, 3, 1, "", "with_standard_arguments"]], "graphnet.utilities.argparse.Options": [[84, 3, 1, "", "contains"], [84, 3, 1, "", "pop_default"]], "graphnet.utilities.filesys": [[93, 5, 1, "", "find_i3_files"], [93, 5, 1, "", "has_extension"], [93, 5, 1, "", "is_gcd_file"], [93, 5, 1, "", "is_i3_file"]], "graphnet.utilities.imports": [[94, 5, 1, "", "has_icecube_package"], [94, 5, 1, "", "has_pisa_package"], [94, 5, 1, "", "has_torch_package"], [94, 5, 1, "", "requires_icecube"]], "graphnet.utilities.logging": [[95, 1, 1, "", "Logger"], [95, 1, 1, "", "RepeatFilter"]], "graphnet.utilities.logging.Logger": [[95, 3, 1, "", "critical"], [95, 3, 1, "", "debug"], [95, 3, 1, "", "error"], [95, 4, 1, "", "file_handlers"], [95, 4, 1, "", "handlers"], [95, 3, 1, "", "info"], [95, 3, 1, "", "setLevel"], [95, 4, 1, "", "stream_handlers"], [95, 3, 1, "", "warning"], [95, 3, 1, "", "warning_once"]], "graphnet.utilities.logging.RepeatFilter": [[95, 3, 1, "", "filter"], [95, 2, 1, "", "nb_repeats_allowed"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:attribute", "3": "py:method", "4": "py:property", "5": "py:function"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "attribute", "Python attribute"], "3": ["py", "method", "Python method"], "4": ["py", "property", "Python property"], "5": ["py", "function", "Python function"]}, "titleterms": {"about": [0, 99], "impact": [0, 99], "usag": [0, 99], "acknowledg": [0, 99], "api": 1, "constant": [2, 4], "data": 3, "dataconvert": 5, "dataload": 6, "dataset": [7, 8], "parquet": [9, 31], "parquet_dataset": 10, "sqlite": [11, 34], "sqlite_dataset": 12, "sqlite_dataset_perturb": 13, "extractor": 14, "i3extractor": 15, "i3featureextractor": 16, "i3genericextractor": 17, "i3hybridrecoextractor": 18, "i3ntmuonlabelsextractor": 19, "i3particleextractor": 20, "i3pisaextractor": 21, "i3quesoextractor": 22, "i3retroextractor": 23, "i3splinempeextractor": 24, "i3truthextractor": 25, "i3tumextractor": 26, "util": [27, 37, 73, 81, 83], "collect": 28, "frame": 29, "type": 30, "parquet_dataconvert": 32, "pipelin": 33, "sqlite_dataconvert": 35, "sqlite_util": 36, "parquet_to_sqlit": 38, "random": 39, "string_selection_resolv": 40, "deploy": [41, 43], "i3modul": 42, "graphnet_modul": 44, "model": [45, 67], "coarsen": 46, "compon": 47, "layer": 48, "pool": 49, "detector": [50, 51], "icecub": 52, "prometheu": 53, "gnn": [54, 59], "convnet": 55, "dynedg": 56, "dynedge_jinst": 57, "dynedge_kaggle_tito": 58, "graph": [60, 64], "edg": [61, 62], "graph_definit": 63, "node": [65, 66], "standard_model": 68, "task": [69, 72], "classif": 70, "reconstruct": 71, "pisa": 74, "fit": 75, "plot": 76, "train": 77, "callback": 78, "label": 79, "loss_funct": 80, "weight_fit": 82, "argpars": 84, "config": 85, "base_config": 86, "configur": 87, "dataset_config": 88, "model_config": 89, "pars": 90, "training_config": 91, "decor": 92, "filesi": 93, "import": 94, "log": 95, "math": 96, "src": 97, "contribut": 98, "github": 98, "issu": 98, "pull": 98, "request": 98, "convent": 98, "code": 98, "qualiti": 98, "instal": 100, "icetrai": 100, "stand": 100, "alon": 100, "run": 100, "docker": 100}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 60}, "alltitles": {"About": [[0, "about"], [99, "about"]], "Impact": [[0, "impact"], [99, "impact"]], "Usage": [[0, "usage"], [99, "usage"]], "Acknowledgements": [[0, "acknowledgements"], [99, "acknowledgements"]], "API": [[1, "module-graphnet"]], "constants": [[2, "module-graphnet.constants"], [4, "module-graphnet.data.constants"]], "data": [[3, "module-graphnet.data"]], "dataconverter": [[5, "module-graphnet.data.dataconverter"]], "dataloader": [[6, "dataloader"]], "dataset": [[7, "dataset"], [8, "dataset"]], "parquet": [[9, "parquet"], [31, "module-graphnet.data.parquet"]], "parquet_dataset": [[10, "parquet-dataset"]], "sqlite": [[11, "sqlite"], [34, "module-graphnet.data.sqlite"]], "sqlite_dataset": [[12, "sqlite-dataset"]], "sqlite_dataset_perturbed": [[13, "sqlite-dataset-perturbed"]], "extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "utilities": [[27, "module-graphnet.data.extractors.utilities"], [37, "module-graphnet.data.utilities"], [83, "module-graphnet.utilities"]], "collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "types": [[30, "module-graphnet.data.extractors.utilities.types"]], "parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "pipeline": [[33, "pipeline"]], "sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "random": [[39, "module-graphnet.data.utilities.random"]], "string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "deployment": [[41, "module-graphnet.deployment"]], "i3modules": [[42, "i3modules"]], "deployer": [[43, "deployer"]], "graphnet_module": [[44, "graphnet-module"]], "models": [[45, "models"]], "coarsening": [[46, "coarsening"]], "components": [[47, "components"]], "layers": [[48, "layers"]], "pool": [[49, "pool"]], "detector": [[50, "detector"], [51, "detector"]], "icecube": [[52, "icecube"]], "prometheus": [[53, "prometheus"]], "gnn": [[54, "gnn"], [59, "gnn"]], "convnet": [[55, "convnet"]], "dynedge": [[56, "dynedge"]], "dynedge_jinst": [[57, "dynedge-jinst"]], "dynedge_kaggle_tito": [[58, "dynedge-kaggle-tito"]], "graphs": [[60, "graphs"], [64, "graphs"]], "edges": [[61, "edges"], [62, "edges"]], "graph_definition": [[63, "graph-definition"]], "nodes": [[65, "nodes"], [66, "nodes"]], "model": [[67, "model"]], "standard_model": [[68, "standard-model"]], "task": [[69, "task"], [72, "task"]], "classification": [[70, "classification"]], "reconstruction": [[71, "reconstruction"]], "utils": [[73, "utils"], [81, "utils"]], "pisa": [[74, "module-graphnet.pisa"]], "fitting": [[75, "module-graphnet.pisa.fitting"]], "plotting": [[76, "module-graphnet.pisa.plotting"]], "training": [[77, "module-graphnet.training"]], "callbacks": [[78, "callbacks"]], "labels": [[79, "labels"]], "loss_functions": [[80, "loss-functions"]], "weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "argparse": [[84, "module-graphnet.utilities.argparse"]], "config": [[85, "config"]], "base_config": [[86, "base-config"]], "configurable": [[87, "configurable"]], "dataset_config": [[88, "dataset-config"]], "model_config": [[89, "model-config"]], "parsing": [[90, "parsing"]], "training_config": [[91, "training-config"]], "decorators": [[92, "module-graphnet.utilities.decorators"]], "filesys": [[93, "module-graphnet.utilities.filesys"]], "imports": [[94, "module-graphnet.utilities.imports"]], "logging": [[95, "module-graphnet.utilities.logging"]], "maths": [[96, "maths"]], "src": [[97, "src"]], "Contribute": [[98, "contribute"]], "GitHub issues": [[98, "github-issues"]], "Pull requests": [[98, "pull-requests"]], "Conventions": [[98, "conventions"]], "Code quality": [[98, "code-quality"]], "Install": [[100, "install"]], "Installing with IceTray": [[100, "installing-with-icetray"]], "Installing stand-alone": [[100, "installing-stand-alone"]], "Running in Docker": [[100, "running-in-docker"]]}, "indexentries": {"graphnet": [[1, "module-graphnet"]], "module": [[1, "module-graphnet"], [2, "module-graphnet.constants"], [3, "module-graphnet.data"], [4, "module-graphnet.data.constants"], [5, "module-graphnet.data.dataconverter"], [14, "module-graphnet.data.extractors"], [15, "module-graphnet.data.extractors.i3extractor"], [16, "module-graphnet.data.extractors.i3featureextractor"], [17, "module-graphnet.data.extractors.i3genericextractor"], [18, "module-graphnet.data.extractors.i3hybridrecoextractor"], [19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"], [20, "module-graphnet.data.extractors.i3particleextractor"], [21, "module-graphnet.data.extractors.i3pisaextractor"], [22, "module-graphnet.data.extractors.i3quesoextractor"], [23, "module-graphnet.data.extractors.i3retroextractor"], [24, "module-graphnet.data.extractors.i3splinempeextractor"], [25, "module-graphnet.data.extractors.i3truthextractor"], [26, "module-graphnet.data.extractors.i3tumextractor"], [27, "module-graphnet.data.extractors.utilities"], [28, "module-graphnet.data.extractors.utilities.collections"], [29, "module-graphnet.data.extractors.utilities.frames"], [30, "module-graphnet.data.extractors.utilities.types"], [31, "module-graphnet.data.parquet"], [32, "module-graphnet.data.parquet.parquet_dataconverter"], [34, "module-graphnet.data.sqlite"], [35, "module-graphnet.data.sqlite.sqlite_dataconverter"], [36, "module-graphnet.data.sqlite.sqlite_utilities"], [37, "module-graphnet.data.utilities"], [38, "module-graphnet.data.utilities.parquet_to_sqlite"], [39, "module-graphnet.data.utilities.random"], [40, "module-graphnet.data.utilities.string_selection_resolver"], [41, "module-graphnet.deployment"], [74, "module-graphnet.pisa"], [75, "module-graphnet.pisa.fitting"], [76, "module-graphnet.pisa.plotting"], [77, "module-graphnet.training"], [82, "module-graphnet.training.weight_fitting"], [83, "module-graphnet.utilities"], [84, "module-graphnet.utilities.argparse"], [92, "module-graphnet.utilities.decorators"], [93, "module-graphnet.utilities.filesys"], [94, "module-graphnet.utilities.imports"], [95, "module-graphnet.utilities.logging"]], "graphnet.constants": [[2, "module-graphnet.constants"]], "graphnet.data": [[3, "module-graphnet.data"]], "deepcore (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.DEEPCORE"]], "deepcore (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.DEEPCORE"]], "features (class in graphnet.data.constants)": [[4, "graphnet.data.constants.FEATURES"]], "icecube86 (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.ICECUBE86"]], "icecube86 (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.ICECUBE86"]], "kaggle (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.KAGGLE"]], "kaggle (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.KAGGLE"]], "prometheus (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.PROMETHEUS"]], "prometheus (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.PROMETHEUS"]], "truth (class in graphnet.data.constants)": [[4, "graphnet.data.constants.TRUTH"]], "upgrade (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.UPGRADE"]], "upgrade (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.UPGRADE"]], "graphnet.data.constants": [[4, "module-graphnet.data.constants"]], "dataconverter (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.DataConverter"]], "fileset (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.FileSet"]], "cache_output_files() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.cache_output_files"]], "execute() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.execute"]], "file_suffix (graphnet.data.dataconverter.dataconverter property)": [[5, "graphnet.data.dataconverter.DataConverter.file_suffix"]], "gcd_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.gcd_file"]], "get_map_function() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.get_map_function"]], "graphnet.data.dataconverter": [[5, "module-graphnet.data.dataconverter"]], "i3_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.i3_file"]], "init_global_index() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.init_global_index"]], "merge_files() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.merge_files"]], "save_data() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.save_data"]], "graphnet.data.extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor"]], "i3extractorcollection (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "name (graphnet.data.extractors.i3extractor.i3extractor property)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.name"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractor method)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.set_files"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractorcollection method)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection.set_files"]], "i3featureextractor (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"]], "i3featureextractoricecube86 (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCube86"]], "i3featureextractoricecubedeepcore (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeDeepCore"]], "i3featureextractoricecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeUpgrade"]], "i3pulsenoisetruthflagicecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor (class in graphnet.data.extractors.i3genericextractor)": [[17, "graphnet.data.extractors.i3genericextractor.I3GenericExtractor"]], "graphnet.data.extractors.i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3galacticplanehybridrecoextractor (class in graphnet.data.extractors.i3hybridrecoextractor)": [[18, "graphnet.data.extractors.i3hybridrecoextractor.I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelextractor (class in graphnet.data.extractors.i3ntmuonlabelsextractor)": [[19, "graphnet.data.extractors.i3ntmuonlabelsextractor.I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor (class in graphnet.data.extractors.i3particleextractor)": [[20, "graphnet.data.extractors.i3particleextractor.I3ParticleExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor (class in graphnet.data.extractors.i3pisaextractor)": [[21, "graphnet.data.extractors.i3pisaextractor.I3PISAExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor (class in graphnet.data.extractors.i3quesoextractor)": [[22, "graphnet.data.extractors.i3quesoextractor.I3QUESOExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor (class in graphnet.data.extractors.i3retroextractor)": [[23, "graphnet.data.extractors.i3retroextractor.I3RetroExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeicextractor (class in graphnet.data.extractors.i3splinempeextractor)": [[24, "graphnet.data.extractors.i3splinempeextractor.I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor (class in graphnet.data.extractors.i3truthextractor)": [[25, "graphnet.data.extractors.i3truthextractor.I3TruthExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor (class in graphnet.data.extractors.i3tumextractor)": [[26, "graphnet.data.extractors.i3tumextractor.I3TUMExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "graphnet.data.extractors.utilities": [[27, "module-graphnet.data.extractors.utilities"]], "flatten_nested_dictionary() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.flatten_nested_dictionary"]], "graphnet.data.extractors.utilities.collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "serialise() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.serialise"]], "transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.transpose_list_of_dicts"]], "frame_is_montecarlo() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_montecarlo"]], "frame_is_noise() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_noise"]], "get_om_keys_and_pulseseries() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "break_cyclic_recursion() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.break_cyclic_recursion"]], "cast_object_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_object_to_pure_python"]], "cast_pulse_series_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_pulse_series_to_pure_python"]], "get_member_variables() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.get_member_variables"]], "graphnet.data.extractors.utilities.types": [[30, "module-graphnet.data.extractors.utilities.types"]], "is_boost_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_class"]], "is_boost_enum() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_enum"]], "is_icecube_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_icecube_class"]], "is_method() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_method"]], "is_type() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_type"]], "graphnet.data.parquet": [[31, "module-graphnet.data.parquet"]], "parquetdataconverter (class in graphnet.data.parquet.parquet_dataconverter)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter"]], "file_suffix (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter attribute)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.file_suffix"]], "graphnet.data.parquet.parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "merge_files() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.merge_files"]], "save_data() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.save_data"]], "graphnet.data.sqlite": [[34, "module-graphnet.data.sqlite"]], "sqlitedataconverter (class in graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter"]], "any_pulsemap_is_non_empty() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.any_pulsemap_is_non_empty"]], "construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe"]], "file_suffix (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter attribute)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.file_suffix"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "is_mc_tree() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_mc_tree"]], "is_pulse_map() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_pulse_map"]], "merge_files() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.merge_files"]], "save_data() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.save_data"]], "attach_index() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.attach_index"]], "create_table() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table"]], "create_table_and_save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table_and_save_to_sql"]], "database_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_exists"]], "database_table_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_table_exists"]], "graphnet.data.sqlite.sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "run_sql_code() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.run_sql_code"]], "save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.save_to_sql"]], "graphnet.data.utilities": [[37, "module-graphnet.data.utilities"]], "parquettosqliteconverter (class in graphnet.data.utilities.parquet_to_sqlite)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "run() (graphnet.data.utilities.parquet_to_sqlite.parquettosqliteconverter method)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter.run"]], "graphnet.data.utilities.random": [[39, "module-graphnet.data.utilities.random"]], "pairwise_shuffle() (in module graphnet.data.utilities.random)": [[39, "graphnet.data.utilities.random.pairwise_shuffle"]], "stringselectionresolver (class in graphnet.data.utilities.string_selection_resolver)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "resolve() (graphnet.data.utilities.string_selection_resolver.stringselectionresolver method)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver.resolve"]], "graphnet.deployment": [[41, "module-graphnet.deployment"]], "graphnet.pisa": [[74, "module-graphnet.pisa"]], "contourfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.ContourFitter"]], "weightfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.WeightFitter"]], "config_updater() (in module graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.config_updater"]], "fit_1d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_1d_contour"]], "fit_2d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_2d_contour"]], "fit_weights() (graphnet.pisa.fitting.weightfitter method)": [[75, "graphnet.pisa.fitting.WeightFitter.fit_weights"]], "graphnet.pisa.fitting": [[75, "module-graphnet.pisa.fitting"]], "graphnet.pisa.plotting": [[76, "module-graphnet.pisa.plotting"]], "plot_1d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_1D_contour"]], "plot_2d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_2D_contour"]], "read_entry() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.read_entry"]], "graphnet.training": [[77, "module-graphnet.training"]], "bjoernlow (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.BjoernLow"]], "uniform (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.Uniform"]], "weightfitter (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.WeightFitter"]], "fit() (graphnet.training.weight_fitting.weightfitter method)": [[82, "graphnet.training.weight_fitting.WeightFitter.fit"]], "graphnet.training.weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "graphnet.utilities": [[83, "module-graphnet.utilities"]], "argumentparser (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.ArgumentParser"]], "options (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.Options"]], "contains() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.contains"]], "graphnet.utilities.argparse": [[84, "module-graphnet.utilities.argparse"]], "pop_default() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.pop_default"]], "standard_arguments (graphnet.utilities.argparse.argumentparser attribute)": [[84, "graphnet.utilities.argparse.ArgumentParser.standard_arguments"]], "with_standard_arguments() (graphnet.utilities.argparse.argumentparser method)": [[84, "graphnet.utilities.argparse.ArgumentParser.with_standard_arguments"]], "graphnet.utilities.decorators": [[92, "module-graphnet.utilities.decorators"]], "find_i3_files() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.find_i3_files"]], "graphnet.utilities.filesys": [[93, "module-graphnet.utilities.filesys"]], "has_extension() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.has_extension"]], "is_gcd_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_gcd_file"]], "is_i3_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_i3_file"]], "graphnet.utilities.imports": [[94, "module-graphnet.utilities.imports"]], "has_icecube_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_icecube_package"]], "has_pisa_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_pisa_package"]], "has_torch_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_torch_package"]], "requires_icecube() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.requires_icecube"]], "logger (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.Logger"]], "repeatfilter (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.RepeatFilter"]], "critical() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.critical"]], "debug() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.debug"]], "error() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.error"]], "file_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.file_handlers"]], "filter() (graphnet.utilities.logging.repeatfilter method)": [[95, "graphnet.utilities.logging.RepeatFilter.filter"]], "graphnet.utilities.logging": [[95, "module-graphnet.utilities.logging"]], "handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.handlers"]], "info() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.info"]], "nb_repeats_allowed (graphnet.utilities.logging.repeatfilter attribute)": [[95, "graphnet.utilities.logging.RepeatFilter.nb_repeats_allowed"]], "setlevel() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.setLevel"]], "stream_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.stream_handlers"]], "warning() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning"]], "warning_once() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning_once"]]}})
\ No newline at end of file
+Search.setIndex({"docnames": ["about", "api/graphnet", "api/graphnet.constants", "api/graphnet.data", "api/graphnet.data.constants", "api/graphnet.data.dataconverter", "api/graphnet.data.dataloader", "api/graphnet.data.dataset", "api/graphnet.data.dataset.dataset", "api/graphnet.data.dataset.parquet", "api/graphnet.data.dataset.parquet.parquet_dataset", "api/graphnet.data.dataset.sqlite", "api/graphnet.data.dataset.sqlite.sqlite_dataset", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed", "api/graphnet.data.extractors", "api/graphnet.data.extractors.i3extractor", "api/graphnet.data.extractors.i3featureextractor", "api/graphnet.data.extractors.i3genericextractor", "api/graphnet.data.extractors.i3hybridrecoextractor", "api/graphnet.data.extractors.i3ntmuonlabelsextractor", "api/graphnet.data.extractors.i3particleextractor", "api/graphnet.data.extractors.i3pisaextractor", "api/graphnet.data.extractors.i3quesoextractor", "api/graphnet.data.extractors.i3retroextractor", "api/graphnet.data.extractors.i3splinempeextractor", "api/graphnet.data.extractors.i3truthextractor", "api/graphnet.data.extractors.i3tumextractor", "api/graphnet.data.extractors.utilities", "api/graphnet.data.extractors.utilities.collections", "api/graphnet.data.extractors.utilities.frames", "api/graphnet.data.extractors.utilities.types", "api/graphnet.data.parquet", "api/graphnet.data.parquet.parquet_dataconverter", "api/graphnet.data.pipeline", "api/graphnet.data.sqlite", "api/graphnet.data.sqlite.sqlite_dataconverter", "api/graphnet.data.sqlite.sqlite_utilities", "api/graphnet.data.utilities", "api/graphnet.data.utilities.parquet_to_sqlite", "api/graphnet.data.utilities.random", "api/graphnet.data.utilities.string_selection_resolver", "api/graphnet.deployment", "api/graphnet.deployment.i3modules", "api/graphnet.deployment.i3modules.deployer", "api/graphnet.deployment.i3modules.graphnet_module", "api/graphnet.models", "api/graphnet.models.coarsening", "api/graphnet.models.components", "api/graphnet.models.components.layers", "api/graphnet.models.components.pool", "api/graphnet.models.detector", "api/graphnet.models.detector.detector", "api/graphnet.models.detector.icecube", "api/graphnet.models.detector.prometheus", "api/graphnet.models.gnn", "api/graphnet.models.gnn.convnet", "api/graphnet.models.gnn.dynedge", "api/graphnet.models.gnn.dynedge_jinst", "api/graphnet.models.gnn.dynedge_kaggle_tito", "api/graphnet.models.gnn.gnn", "api/graphnet.models.graphs", "api/graphnet.models.graphs.edges", "api/graphnet.models.graphs.edges.edges", "api/graphnet.models.graphs.graph_definition", "api/graphnet.models.graphs.graphs", "api/graphnet.models.graphs.nodes", "api/graphnet.models.graphs.nodes.nodes", "api/graphnet.models.model", "api/graphnet.models.standard_model", "api/graphnet.models.task", "api/graphnet.models.task.classification", "api/graphnet.models.task.reconstruction", "api/graphnet.models.task.task", "api/graphnet.models.utils", "api/graphnet.pisa", "api/graphnet.pisa.fitting", "api/graphnet.pisa.plotting", "api/graphnet.training", "api/graphnet.training.callbacks", "api/graphnet.training.labels", "api/graphnet.training.loss_functions", "api/graphnet.training.utils", "api/graphnet.training.weight_fitting", "api/graphnet.utilities", "api/graphnet.utilities.argparse", "api/graphnet.utilities.config", "api/graphnet.utilities.config.base_config", "api/graphnet.utilities.config.configurable", "api/graphnet.utilities.config.dataset_config", "api/graphnet.utilities.config.model_config", "api/graphnet.utilities.config.parsing", "api/graphnet.utilities.config.training_config", "api/graphnet.utilities.decorators", "api/graphnet.utilities.filesys", "api/graphnet.utilities.imports", "api/graphnet.utilities.logging", "api/graphnet.utilities.maths", "api/modules", "contribute", "index", "install"], "filenames": ["about.md", "api/graphnet.rst", "api/graphnet.constants.rst", "api/graphnet.data.rst", "api/graphnet.data.constants.rst", "api/graphnet.data.dataconverter.rst", "api/graphnet.data.dataloader.rst", "api/graphnet.data.dataset.rst", "api/graphnet.data.dataset.dataset.rst", "api/graphnet.data.dataset.parquet.rst", "api/graphnet.data.dataset.parquet.parquet_dataset.rst", "api/graphnet.data.dataset.sqlite.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.rst", "api/graphnet.data.extractors.rst", "api/graphnet.data.extractors.i3extractor.rst", "api/graphnet.data.extractors.i3featureextractor.rst", "api/graphnet.data.extractors.i3genericextractor.rst", "api/graphnet.data.extractors.i3hybridrecoextractor.rst", "api/graphnet.data.extractors.i3ntmuonlabelsextractor.rst", "api/graphnet.data.extractors.i3particleextractor.rst", "api/graphnet.data.extractors.i3pisaextractor.rst", "api/graphnet.data.extractors.i3quesoextractor.rst", "api/graphnet.data.extractors.i3retroextractor.rst", "api/graphnet.data.extractors.i3splinempeextractor.rst", "api/graphnet.data.extractors.i3truthextractor.rst", "api/graphnet.data.extractors.i3tumextractor.rst", "api/graphnet.data.extractors.utilities.rst", "api/graphnet.data.extractors.utilities.collections.rst", "api/graphnet.data.extractors.utilities.frames.rst", "api/graphnet.data.extractors.utilities.types.rst", "api/graphnet.data.parquet.rst", "api/graphnet.data.parquet.parquet_dataconverter.rst", "api/graphnet.data.pipeline.rst", "api/graphnet.data.sqlite.rst", "api/graphnet.data.sqlite.sqlite_dataconverter.rst", "api/graphnet.data.sqlite.sqlite_utilities.rst", "api/graphnet.data.utilities.rst", "api/graphnet.data.utilities.parquet_to_sqlite.rst", "api/graphnet.data.utilities.random.rst", "api/graphnet.data.utilities.string_selection_resolver.rst", "api/graphnet.deployment.rst", "api/graphnet.deployment.i3modules.rst", "api/graphnet.deployment.i3modules.deployer.rst", "api/graphnet.deployment.i3modules.graphnet_module.rst", "api/graphnet.models.rst", "api/graphnet.models.coarsening.rst", "api/graphnet.models.components.rst", "api/graphnet.models.components.layers.rst", "api/graphnet.models.components.pool.rst", "api/graphnet.models.detector.rst", "api/graphnet.models.detector.detector.rst", "api/graphnet.models.detector.icecube.rst", "api/graphnet.models.detector.prometheus.rst", "api/graphnet.models.gnn.rst", "api/graphnet.models.gnn.convnet.rst", "api/graphnet.models.gnn.dynedge.rst", "api/graphnet.models.gnn.dynedge_jinst.rst", "api/graphnet.models.gnn.dynedge_kaggle_tito.rst", "api/graphnet.models.gnn.gnn.rst", "api/graphnet.models.graphs.rst", "api/graphnet.models.graphs.edges.rst", "api/graphnet.models.graphs.edges.edges.rst", "api/graphnet.models.graphs.graph_definition.rst", "api/graphnet.models.graphs.graphs.rst", "api/graphnet.models.graphs.nodes.rst", "api/graphnet.models.graphs.nodes.nodes.rst", "api/graphnet.models.model.rst", "api/graphnet.models.standard_model.rst", "api/graphnet.models.task.rst", "api/graphnet.models.task.classification.rst", "api/graphnet.models.task.reconstruction.rst", "api/graphnet.models.task.task.rst", "api/graphnet.models.utils.rst", "api/graphnet.pisa.rst", "api/graphnet.pisa.fitting.rst", "api/graphnet.pisa.plotting.rst", "api/graphnet.training.rst", "api/graphnet.training.callbacks.rst", "api/graphnet.training.labels.rst", "api/graphnet.training.loss_functions.rst", "api/graphnet.training.utils.rst", "api/graphnet.training.weight_fitting.rst", "api/graphnet.utilities.rst", "api/graphnet.utilities.argparse.rst", "api/graphnet.utilities.config.rst", "api/graphnet.utilities.config.base_config.rst", "api/graphnet.utilities.config.configurable.rst", "api/graphnet.utilities.config.dataset_config.rst", "api/graphnet.utilities.config.model_config.rst", "api/graphnet.utilities.config.parsing.rst", "api/graphnet.utilities.config.training_config.rst", "api/graphnet.utilities.decorators.rst", "api/graphnet.utilities.filesys.rst", "api/graphnet.utilities.imports.rst", "api/graphnet.utilities.logging.rst", "api/graphnet.utilities.maths.rst", "api/modules.rst", "contribute.md", "index.rst", "install.md"], "titles": ["About", "API", "constants", "data", "constants", "dataconverter", "dataloader", "dataset", "dataset", "parquet", "parquet_dataset", "sqlite", "sqlite_dataset", "sqlite_dataset_perturbed", "extractors", "i3extractor", "i3featureextractor", "i3genericextractor", "i3hybridrecoextractor", "i3ntmuonlabelsextractor", "i3particleextractor", "i3pisaextractor", "i3quesoextractor", "i3retroextractor", "i3splinempeextractor", "i3truthextractor", "i3tumextractor", "utilities", "collections", "frames", "types", "parquet", "parquet_dataconverter", "pipeline", "sqlite", "sqlite_dataconverter", "sqlite_utilities", "utilities", "parquet_to_sqlite", "random", "string_selection_resolver", "deployment", "i3modules", "deployer", "graphnet_module", "models", "coarsening", "components", "layers", "pool", "detector", "detector", "icecube", "prometheus", "gnn", "convnet", "dynedge", "dynedge_jinst", "dynedge_kaggle_tito", "gnn", "graphs", "edges", "edges", "graph_definition", "graphs", "nodes", "nodes", "model", "standard_model", "task", "classification", "reconstruction", "task", "utils", "pisa", "fitting", "plotting", "training", "callbacks", "labels", "loss_functions", "utils", "weight_fitting", "utilities", "argparse", "config", "base_config", "configurable", "dataset_config", "model_config", "parsing", "training_config", "decorators", "filesys", "imports", "logging", "maths", "src", "Contribute", "About", "Install"], "terms": {"graphnet": [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 33, 35, 36, 37, 38, 39, 40, 41, 44, 45, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 98, 99, 100], "i": [0, 1, 8, 10, 12, 13, 15, 17, 28, 29, 30, 35, 36, 39, 40, 44, 46, 49, 55, 56, 62, 66, 70, 71, 72, 73, 76, 78, 79, 80, 82, 84, 89, 90, 93, 94, 95, 98, 99, 100], "an": [0, 5, 30, 32, 33, 35, 40, 44, 63, 80, 93, 95, 98, 99, 100], "open": [0, 98, 99], "sourc": [0, 4, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 98, 99], "python": [0, 1, 5, 14, 15, 17, 28, 30, 98, 99, 100], "framework": [0, 99], "aim": [0, 1, 98, 99], "provid": [0, 1, 8, 10, 12, 13, 44, 45, 80, 98, 99, 100], "high": [0, 99], "qualiti": [0, 99], "user": [0, 45, 78, 99, 100], "friendli": [0, 99], "end": [0, 1, 5, 32, 35, 99], "function": [0, 5, 6, 8, 30, 36, 39, 44, 46, 49, 52, 53, 63, 67, 70, 71, 72, 73, 75, 76, 80, 81, 83, 88, 89, 90, 93, 94, 96, 99], "perform": [0, 46, 48, 49, 54, 56, 58, 68, 70, 71, 72, 99], "reconstruct": [0, 1, 16, 18, 19, 23, 24, 26, 33, 41, 45, 58, 69, 72, 99], "task": [0, 1, 45, 68, 70, 71, 80, 98, 99], "neutrino": [0, 1, 48, 58, 75, 99], "telescop": [0, 1, 99], "us": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 15, 20, 25, 27, 28, 32, 33, 35, 36, 37, 38, 40, 41, 44, 45, 48, 49, 51, 56, 57, 58, 62, 63, 64, 67, 69, 70, 71, 72, 73, 75, 78, 79, 80, 82, 83, 84, 85, 86, 88, 89, 90, 91, 94, 95, 98, 99, 100], "graph": [0, 1, 6, 8, 10, 12, 13, 44, 45, 48, 49, 51, 61, 62, 63, 65, 66, 73, 79, 81, 98, 99], "neural": [0, 1, 99], "network": [0, 1, 55, 99], "gnn": [0, 1, 33, 45, 55, 56, 57, 58, 63, 68, 99, 100], "make": [0, 5, 82, 88, 89, 98, 99, 100], "fast": [0, 99, 100], "easi": [0, 99], "train": [0, 1, 7, 13, 40, 41, 44, 63, 68, 78, 79, 80, 81, 82, 84, 88, 89, 91, 97, 99, 100], "complex": [0, 45, 99], "model": [0, 1, 13, 41, 44, 46, 47, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 68, 69, 70, 71, 72, 73, 76, 77, 78, 80, 81, 84, 86, 88, 89, 91, 97, 99, 100], "can": [0, 1, 8, 10, 12, 13, 15, 17, 20, 38, 44, 49, 63, 75, 76, 82, 84, 86, 88, 89, 98, 99, 100], "event": [0, 1, 8, 10, 12, 13, 22, 36, 38, 40, 44, 49, 63, 70, 71, 72, 73, 75, 80, 82, 88, 99], "state": [0, 99], "art": [0, 99], "arbitrari": [0, 99], "detector": [0, 1, 25, 45, 52, 53, 63, 64, 66, 68, 99], "configur": [0, 1, 8, 45, 67, 68, 75, 83, 85, 86, 88, 89, 91, 95, 99], "infer": [0, 1, 33, 41, 44, 68, 70, 71, 72, 99, 100], "time": [0, 4, 36, 46, 49, 71, 95, 99, 100], "ar": [0, 1, 4, 5, 8, 10, 12, 13, 17, 30, 32, 35, 38, 40, 44, 49, 56, 58, 60, 61, 62, 63, 64, 65, 70, 75, 80, 82, 88, 89, 98, 99, 100], "order": [0, 28, 46, 73, 99], "magnitud": [0, 99], "faster": [0, 99], "than": [0, 6, 70, 71, 72, 81, 95, 99], "tradit": [0, 99], "techniqu": [0, 99], "common": [0, 1, 80, 86, 91, 92, 94, 99], "ml": [0, 1, 99], "develop": [0, 1, 98, 99, 100], "physicist": [0, 1, 99], "wish": [0, 98, 99], "tool": [0, 1, 99], "research": [0, 99], "By": [0, 38, 70, 71, 72, 99], "unit": [0, 5, 94, 98, 99], "both": [0, 17, 70, 71, 72, 76, 99], "group": [0, 5, 32, 35, 49, 99], "increas": [0, 78, 99], "longev": [0, 99], "usabl": [0, 99], "individu": [0, 5, 8, 10, 12, 13, 49, 56, 73, 99], "code": [0, 25, 36, 63, 88, 89, 99], "contribut": [0, 99, 100], "from": [0, 1, 6, 8, 10, 12, 13, 14, 15, 17, 19, 20, 22, 28, 29, 30, 33, 35, 38, 44, 49, 58, 62, 63, 66, 67, 70, 71, 72, 73, 76, 78, 79, 80, 86, 87, 88, 89, 91, 95, 98, 99, 100], "build": [0, 1, 45, 51, 62, 66, 67, 86, 88, 89, 99], "gener": [0, 5, 8, 10, 12, 13, 17, 44, 60, 61, 65, 70, 80, 99], "reusabl": [0, 99], "softwar": [0, 80, 99], "packag": [0, 1, 39, 90, 93, 94, 98, 99, 100], "base": [0, 4, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 33, 35, 38, 40, 44, 46, 48, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 82, 84, 86, 87, 88, 89, 91, 94, 95, 99], "engin": [0, 99], "best": [0, 98, 99], "practic": [0, 98, 99], "lower": [0, 76, 99], "technic": [0, 99], "threshold": [0, 44, 62, 99], "most": [0, 1, 40, 99, 100], "scientif": [0, 1, 99], "problem": [0, 62, 98, 99], "The": [0, 5, 8, 10, 12, 28, 30, 33, 35, 36, 44, 46, 48, 49, 56, 58, 62, 63, 70, 71, 72, 73, 75, 76, 78, 79, 80, 99], "improv": [0, 1, 84, 99], "classif": [0, 1, 45, 69, 72, 80, 99], "yield": [0, 56, 75, 80, 99], "veri": [0, 40, 99], "accur": [0, 99], "e": [0, 1, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 28, 30, 32, 33, 35, 36, 40, 44, 46, 48, 49, 51, 52, 53, 55, 59, 62, 63, 66, 67, 68, 70, 71, 72, 73, 78, 79, 80, 82, 86, 95, 98, 99, 100], "g": [0, 1, 5, 8, 10, 12, 13, 25, 28, 30, 32, 33, 35, 36, 40, 44, 49, 63, 70, 71, 72, 73, 82, 95, 98, 99, 100], "low": [0, 99], "energi": [0, 4, 33, 70, 71, 72, 82, 99], "observ": [0, 99], "icecub": [0, 1, 16, 29, 30, 45, 48, 50, 58, 94, 99, 100], "here": [0, 98, 99, 100], "implement": [0, 1, 5, 15, 31, 32, 34, 35, 48, 55, 56, 57, 58, 62, 80, 98, 99], "wa": [0, 99], "appli": [0, 8, 10, 12, 13, 15, 49, 55, 56, 57, 58, 59, 68, 90, 99], "oscil": [0, 74, 99], "lead": [0, 99], "signific": [0, 99], "angular": [0, 99], "rang": [0, 70, 71, 72, 99], "relev": [0, 1, 30, 39, 93, 98, 99], "studi": [0, 99], "furthermor": [0, 99], "shown": [0, 99], "could": [0, 98, 99], "muon": [0, 19, 99], "v": [0, 99], "therebi": [0, 1, 88, 89, 99], "effici": [0, 99], "puriti": [0, 99], "sampl": [0, 40, 99], "analysi": [0, 33, 99, 100], "similarli": [0, 30, 99], "ha": [0, 5, 30, 32, 35, 36, 44, 55, 80, 93, 99, 100], "great": [0, 99], "point": [0, 24, 79, 80, 99], "analys": [0, 41, 74, 99], "final": [0, 49, 78, 88, 99], "millisecond": [0, 99], "allow": [0, 41, 45, 49, 78, 86, 91, 99, 100], "whole": [0, 99], "new": [0, 1, 35, 48, 86, 91, 98, 99], "type": [0, 5, 6, 8, 10, 12, 13, 14, 15, 27, 28, 29, 32, 35, 36, 38, 39, 40, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 72, 73, 75, 76, 78, 80, 81, 82, 84, 86, 87, 88, 89, 90, 93, 94, 95, 96, 98, 99], "cosmic": [0, 99], "alert": [0, 99], "which": [0, 8, 10, 12, 13, 15, 16, 25, 29, 33, 40, 44, 46, 49, 56, 67, 70, 75, 80, 84, 99, 100], "were": [0, 99], "previous": [0, 99], "unfeas": [0, 99], "possibl": [0, 28, 98, 99], "identifi": [0, 5, 8, 10, 12, 13, 25, 88, 89, 99], "10": [0, 33, 84, 99], "tev": [0, 99], "monitor": [0, 99], "rate": [0, 78, 99], "direct": [0, 58, 70, 71, 72, 77, 79, 99], "real": [0, 99], "thi": [0, 3, 5, 8, 10, 12, 13, 15, 17, 30, 32, 35, 36, 39, 44, 45, 49, 56, 66, 68, 70, 71, 72, 73, 75, 76, 78, 80, 82, 86, 88, 89, 91, 95, 98, 99, 100], "enabl": [0, 3, 99], "first": [0, 78, 86, 91, 98, 99], "ever": [0, 99], "despit": [0, 99], "larg": [0, 80, 99], "background": [0, 99], "origin": [0, 75, 99], "compris": [0, 99], "number": [0, 5, 8, 10, 12, 13, 32, 33, 35, 40, 48, 49, 55, 56, 57, 58, 59, 62, 64, 66, 70, 71, 72, 78, 84, 99], "modul": [0, 3, 8, 30, 33, 41, 44, 45, 48, 50, 54, 60, 61, 63, 64, 65, 67, 69, 74, 77, 83, 85, 88, 89, 90, 91, 94, 99], "necessari": [0, 28, 98, 99], "workflow": [0, 99], "ingest": [0, 1, 3, 50, 99], "raw": [0, 66, 99], "data": [0, 1, 4, 5, 6, 8, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 46, 48, 49, 50, 51, 52, 55, 56, 57, 58, 59, 62, 63, 64, 67, 68, 70, 71, 72, 73, 79, 81, 84, 86, 88, 91, 94, 97, 99, 100], "domain": [0, 1, 3, 41, 99], "specif": [0, 1, 3, 5, 8, 10, 12, 16, 30, 31, 32, 34, 35, 36, 41, 46, 49, 50, 51, 52, 53, 54, 59, 62, 63, 66, 68, 69, 70, 71, 72, 80, 98, 99, 100], "format": [0, 1, 3, 5, 8, 28, 32, 35, 76, 88, 98, 99, 100], "deploi": [0, 1, 41, 44, 99], "chain": [0, 1, 41, 45, 68, 99, 100], "illustr": [0, 98, 99], "figur": [0, 76, 99], "level": [0, 8, 10, 12, 13, 25, 36, 46, 49, 62, 67, 95, 99, 100], "overview": [0, 99], "typic": [0, 28, 99], "convert": [0, 1, 3, 5, 28, 32, 35, 38, 99, 100], "industri": [0, 3, 99], "standard": [0, 3, 4, 5, 13, 32, 35, 40, 52, 53, 63, 66, 68, 84, 98, 99], "intermedi": [0, 1, 3, 5, 8, 32, 35, 55, 99, 100], "file": [0, 1, 3, 5, 8, 10, 12, 13, 15, 28, 32, 35, 38, 39, 44, 63, 67, 75, 78, 80, 84, 85, 86, 87, 88, 89, 93, 95, 99, 100], "read": [0, 3, 8, 10, 12, 13, 28, 51, 56, 68, 69, 99, 100], "simpl": [0, 45, 99], "physic": [0, 1, 15, 29, 30, 41, 44, 45, 69, 70, 71, 72, 99], "orient": [0, 45, 99], "compon": [0, 1, 45, 48, 49, 68, 99], "manag": [0, 15, 77, 99], "experi": [0, 1, 77, 99], "log": [0, 1, 71, 77, 78, 80, 83, 99, 100], "deploy": [0, 1, 42, 44, 63, 97, 99], "modular": [0, 45, 99], "subclass": [0, 45, 99], "torch": [0, 8, 10, 12, 13, 45, 48, 63, 64, 67, 68, 94, 99, 100], "nn": [0, 45, 48, 62, 64, 99], "mean": [0, 5, 8, 10, 12, 13, 32, 35, 45, 56, 58, 80, 89, 99], "onli": [0, 1, 8, 10, 12, 13, 45, 49, 70, 71, 72, 75, 82, 89, 94, 99, 100], "need": [0, 28, 45, 67, 80, 99, 100], "import": [0, 1, 36, 45, 83, 99], "few": [0, 45, 98, 99], "exist": [0, 8, 10, 12, 13, 33, 35, 36, 45, 79, 88, 99], "purpos": [0, 45, 80, 99], "built": [0, 45, 99], "them": [0, 1, 28, 45, 56, 70, 71, 72, 75, 99, 100], "togeth": [0, 45, 62, 68, 99], "form": [0, 45, 70, 86, 91, 99], "complet": [0, 45, 68, 99], "extend": [0, 1, 99], "suit": [0, 99], "through": [0, 80, 99], "layer": [0, 45, 47, 49, 55, 56, 57, 58, 70, 71, 72, 99], "connect": [0, 62, 63, 66, 80, 99], "etc": [0, 80, 95, 99], "optimis": [0, 1, 99], "differ": [0, 8, 10, 12, 13, 15, 64, 68, 98, 99, 100], "track": [0, 15, 19, 71, 98, 99], "These": [0, 63, 98, 99], "prepar": [0, 80, 99], "satisfi": [0, 99], "o": [0, 70, 71, 72, 99], "load": [0, 6, 8, 39, 67, 86, 88, 99], "requir": [0, 21, 36, 70, 80, 88, 89, 91, 99, 100], "when": [0, 5, 8, 10, 12, 13, 28, 32, 35, 36, 44, 48, 56, 58, 79, 95, 98, 99, 100], "batch": [0, 6, 33, 46, 48, 49, 68, 73, 81, 84, 99], "do": [0, 44, 80, 88, 89, 98, 99, 100], "predict": [0, 20, 24, 26, 33, 44, 55, 67, 68, 70, 71, 72, 80, 81, 99], "either": [0, 8, 10, 12, 80, 99, 100], "contain": [0, 5, 8, 10, 12, 13, 28, 29, 32, 33, 35, 44, 56, 60, 61, 63, 64, 65, 67, 70, 71, 72, 80, 82, 84, 99, 100], "imag": [0, 1, 98, 99, 100], "portabl": [0, 99], "depend": [0, 99, 100], "free": [0, 80, 99], "split": [0, 46, 99], "up": [0, 5, 32, 35, 44, 98, 99, 100], "interfac": [0, 74, 99, 100], "block": [0, 1, 99], "pre": [0, 13, 51, 63, 79, 98, 99], "directli": [0, 15, 99], "while": [0, 17, 78, 99], "continu": [0, 80, 99], "expand": [0, 99], "": [0, 5, 6, 8, 10, 12, 13, 15, 28, 35, 38, 55, 56, 68, 70, 71, 72, 73, 78, 82, 84, 88, 89, 95, 96, 99, 100], "capabl": [0, 99], "project": [0, 98, 99], "receiv": [0, 99], "fund": [0, 99], "european": [0, 99], "union": [0, 6, 8, 10, 12, 13, 17, 28, 30, 44, 46, 48, 49, 56, 67, 68, 70, 71, 72, 88, 91, 93, 99], "horizon": [0, 99], "2020": [0, 99], "innov": [0, 99], "programm": [0, 99], "under": [0, 13, 99], "mari": [0, 99], "sk\u0142odowska": [0, 99], "curi": [0, 99], "grant": [0, 80, 99], "agreement": [0, 98, 99], "No": [0, 99], "890778": [0, 99], "work": [0, 4, 29, 98, 99, 100], "rasmu": [0, 57, 99], "\u00f8rs\u00f8e": [0, 99], "partli": [0, 99], "punch4nfdi": [0, 99], "consortium": [0, 99], "support": [0, 30, 98, 99, 100], "dfg": [0, 99], "nfdi": [0, 99], "39": [0, 99, 100], "1": [0, 5, 8, 28, 32, 35, 40, 46, 49, 56, 58, 62, 64, 70, 71, 72, 73, 78, 80, 82, 88, 99, 100], "germani": [0, 99], "conveni": [1, 98, 100], "collabor": 1, "solv": [1, 98], "It": [1, 28, 36, 44, 98], "leverag": 1, "advanc": [1, 49], "machin": [1, 100], "learn": [1, 44, 78, 100], "without": [1, 62, 66, 75, 80, 100], "have": [1, 5, 17, 32, 35, 36, 40, 49, 63, 70, 71, 72, 98, 100], "expert": 1, "themselv": [1, 88, 89], "acceler": 1, "area": 1, "phyic": 1, "design": 1, "principl": 1, "all": [1, 5, 8, 10, 12, 13, 15, 17, 32, 35, 36, 44, 48, 49, 51, 56, 59, 63, 67, 72, 80, 86, 87, 88, 89, 90, 91, 95, 98, 100], "streamlin": 1, "process": [1, 5, 13, 15, 44, 51, 56, 98, 100], "transform": [1, 49, 70, 71, 72, 82], "extens": [1, 93], "basic": 1, "across": [1, 2, 8, 10, 12, 13, 30, 37, 49, 68, 80, 83, 84, 85, 95], "variou": 1, "easili": 1, "architectur": [1, 55, 56, 57, 58, 68], "main": [1, 54, 63, 68, 98, 100], "featur": [1, 3, 4, 5, 8, 10, 12, 13, 16, 33, 44, 48, 49, 51, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 70, 73, 81, 88, 98], "i3": [1, 5, 15, 29, 30, 32, 35, 39, 44, 93, 100], "more": [1, 8, 36, 39, 86, 88, 89, 91, 95], "index": [1, 5, 8, 10, 12, 30, 36, 49, 78], "sqlite": [1, 3, 7, 12, 13, 33, 35, 36, 38, 100], "suitabl": 1, "plug": 1, "plai": 1, "abstract": [1, 5, 8, 51, 59, 63, 67, 72, 87], "awai": 1, "detail": [1, 100], "expos": 1, "physicst": 1, "what": [1, 63, 98], "i3modul": [1, 41, 44], "includ": [1, 13, 67, 68, 75, 80, 86, 98], "docker": 1, "run": [1, 38], "containeris": 1, "fashion": 1, "subpackag": [1, 3, 7, 14, 41, 45, 60, 83], "dataset": [1, 3, 6, 9, 10, 11, 12, 13, 19, 40, 63, 84, 88], "extractor": [1, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35, 44], "parquet": [1, 3, 7, 10, 32, 38, 100], "util": [1, 3, 14, 28, 29, 30, 36, 38, 39, 40, 45, 77, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 97], "constant": [1, 3, 97], "dataconvert": [1, 3, 32, 35], "dataload": [1, 3, 33, 63, 67, 68, 81, 91], "pipelin": [1, 3], "coarsen": [1, 45, 49], "standard_model": [1, 45], "pisa": [1, 21, 33, 75, 76, 94, 97, 100], "fit": [1, 67, 74, 76, 80, 82, 91], "plot": [1, 74], "callback": [1, 67, 77], "label": [1, 8, 19, 22, 55, 63, 68, 72, 76, 77, 81], "loss_funct": [1, 70, 71, 72, 77], "weight_fit": [1, 77], "config": [1, 6, 40, 75, 80, 83, 84, 86, 87, 88, 89, 90, 91], "argpars": [1, 83], "decor": [1, 5, 83, 94], "filesi": [1, 83], "math": [1, 83], "submodul": [1, 3, 7, 9, 11, 14, 27, 31, 34, 37, 42, 45, 47, 50, 54, 60, 61, 65, 69, 74, 77, 83, 85, 90], "global": [2, 4, 56, 58, 67], "i3extractor": [3, 5, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35], "i3featureextractor": [3, 4, 14, 35, 44], "i3genericextractor": [3, 14, 35], "i3hybridrecoextractor": [3, 14], "i3ntmuonlabelsextractor": [3, 14], "i3particleextractor": [3, 14], "i3pisaextractor": [3, 14], "i3quesoextractor": [3, 14], "i3retroextractor": [3, 14], "i3splinempeextractor": [3, 14], "i3truthextractor": [3, 4, 14], "i3tumextractor": [3, 14], "parquet_dataconvert": [3, 31], "sqlite_dataconvert": [3, 34], "sqlite_util": [3, 34], "parquet_to_sqlit": [3, 37], "random": [3, 8, 10, 12, 13, 37, 40, 88], "string_selection_resolv": [3, 37], "truth": [3, 4, 8, 10, 12, 13, 16, 25, 33, 36, 63, 81, 82, 88], "fileset": [3, 5], "init_global_index": [3, 5], "cache_output_fil": [3, 5], "collate_fn": [3, 6, 77, 81], "do_shuffl": [3, 6], "insqlitepipelin": [3, 33], "class": [4, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 33, 34, 35, 38, 40, 44, 46, 48, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 82, 84, 86, 87, 88, 89, 90, 91, 95, 98], "object": [4, 5, 8, 10, 12, 13, 15, 17, 28, 30, 44, 49, 51, 63, 70, 71, 72, 75, 84, 95], "namespac": [4, 67], "name": [4, 5, 6, 8, 10, 12, 13, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 30, 32, 33, 35, 36, 38, 44, 62, 63, 64, 66, 67, 70, 71, 72, 75, 79, 82, 84, 86, 88, 89, 90, 91, 95, 98, 100], "icecube86": [4, 50, 52], "dom_x": [4, 44, 46], "dom_i": [4, 44, 46], "dom_z": [4, 44, 46], "dom_tim": 4, "charg": [4, 44, 80], "rde": [4, 46], "pmt_area": [4, 46], "deepcor": [4, 16, 52], "upgrad": [4, 16, 52, 100], "string": [4, 5, 8, 10, 12, 13, 28, 32, 35, 40, 49, 86], "pmt_number": 4, "dom_numb": 4, "pmt_dir_x": 4, "pmt_dir_i": 4, "pmt_dir_z": 4, "dom_typ": 4, "prometheu": [4, 45, 50], "sensor_pos_x": 4, "sensor_pos_i": 4, "sensor_pos_z": 4, "t": [4, 30, 36, 76, 78, 80, 100], "kaggl": [4, 48, 52, 58], "x": [4, 5, 25, 32, 35, 48, 49, 66, 67, 72, 73, 76, 80, 82], "y": [4, 25, 73, 76, 100], "z": [4, 5, 25, 32, 35, 73, 100], "auxiliari": 4, "energy_track": 4, "position_x": 4, "position_i": 4, "position_z": 4, "azimuth": [4, 71, 79], "zenith": [4, 71, 79], "pid": [4, 40, 88], "elast": 4, "sim_typ": 4, "interaction_typ": 4, "interaction_tim": [4, 71], "inelast": [4, 71], "stopped_muon": 4, "injection_energi": 4, "injection_typ": 4, "injection_interaction_typ": 4, "injection_zenith": 4, "injection_azimuth": 4, "injection_bjorkenx": 4, "injection_bjorkeni": 4, "injection_position_x": 4, "injection_position_i": 4, "injection_position_z": 4, "injection_column_depth": 4, "primary_lepton_1_typ": 4, "primary_hadron_1_typ": 4, "primary_lepton_1_position_x": 4, "primary_lepton_1_position_i": 4, "primary_lepton_1_position_z": 4, "primary_hadron_1_position_x": 4, "primary_hadron_1_position_i": 4, "primary_hadron_1_position_z": 4, "primary_lepton_1_direction_theta": 4, "primary_lepton_1_direction_phi": 4, "primary_hadron_1_direction_theta": 4, "primary_hadron_1_direction_phi": 4, "primary_lepton_1_energi": 4, "primary_hadron_1_energi": 4, "total_energi": 4, "i3_fil": [5, 15], "str": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 52, 53, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 79, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 95], "gcd_file": [5, 15, 44], "paramet": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96], "output_fil": [5, 32, 35], "global_index": 5, "avail": [5, 17, 33, 94], "pool": [5, 45, 46, 47, 56, 58], "worker": [5, 32, 33, 35, 39, 84, 95], "return": [5, 6, 8, 10, 12, 13, 15, 28, 29, 30, 32, 35, 36, 38, 39, 40, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 66, 67, 68, 70, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 93, 94, 95, 96], "none": [5, 6, 8, 10, 12, 13, 15, 17, 25, 29, 30, 32, 33, 35, 36, 38, 40, 44, 46, 48, 49, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 80, 81, 82, 84, 86, 87, 88, 90, 93, 95], "synchron": 5, "list": [5, 6, 8, 10, 12, 13, 15, 17, 25, 28, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 76, 78, 80, 81, 82, 88, 90, 91, 93, 95], "process_method": 5, "cach": 5, "output": [5, 32, 35, 38, 55, 56, 57, 59, 66, 67, 68, 75, 82, 88, 89, 100], "typevar": 5, "f": [5, 49], "bound": [5, 76], "callabl": [5, 6, 8, 30, 48, 49, 51, 52, 53, 63, 70, 71, 72, 81, 82, 86, 88, 89, 90, 94], "ani": [5, 6, 8, 10, 12, 28, 29, 30, 32, 35, 44, 48, 49, 56, 62, 63, 67, 68, 70, 72, 76, 80, 82, 84, 86, 87, 88, 89, 90, 91, 95, 100], "outdir": [5, 32, 33, 35, 38, 75], "gcd_rescu": [5, 32, 35, 93], "nb_files_to_batch": [5, 32, 35], "sequential_batch_pattern": [5, 32, 35], "input_file_batch_pattern": [5, 32, 35], "index_column": [5, 8, 10, 12, 13, 32, 35, 36, 40, 75, 81, 82, 88], "icetray_verbos": [5, 32, 35], "abc": [5, 8, 15, 33, 67, 79, 82, 87], "logger": [5, 8, 15, 33, 38, 40, 62, 67, 79, 82, 83, 95, 100], "construct": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 46, 47, 48, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 81, 82, 84, 87, 88, 89, 95], "regular": [5, 30, 32, 35], "express": [5, 32, 35, 67, 80], "accord": [5, 13, 32, 35, 46, 49, 62], "match": [5, 32, 35, 82, 93, 96], "certain": [5, 32, 35, 38, 75], "pattern": [5, 32, 35], "wildcard": [5, 32, 35], "same": [5, 30, 32, 35, 36, 46, 49, 70, 73, 78, 90, 95], "input": [5, 8, 10, 12, 13, 32, 33, 35, 44, 52, 55, 56, 57, 58, 59, 63, 66, 70, 72, 73, 86, 91], "replac": [5, 32, 35, 86, 88, 89, 91], "period": [5, 32, 35], "special": [5, 17, 32, 35, 44, 73], "interpret": [5, 32, 35, 70], "liter": [5, 32, 35], "charact": [5, 32, 35], "regex": [5, 32, 35], "For": [5, 30, 32, 35, 78], "instanc": [5, 8, 15, 25, 30, 32, 35, 44, 63, 67, 75, 79, 81, 87, 100], "A": [5, 8, 32, 33, 35, 44, 49, 64, 73, 75, 80, 82, 100], "_": [5, 32, 35], "0": [5, 8, 10, 12, 32, 35, 40, 44, 46, 49, 55, 56, 58, 62, 64, 73, 75, 76, 80, 88], "9": [5, 32, 35], "5": [5, 8, 10, 12, 32, 35, 40, 84, 100], "zst": [5, 32, 35], "find": [5, 32, 35, 93], "whose": [5, 32, 35, 44], "one": [5, 8, 32, 35, 36, 44, 49, 67, 88, 89, 93, 98, 100], "capit": [5, 32, 35], "letter": [5, 32, 35], "follow": [5, 32, 35, 56, 68, 80, 82, 98, 100], "underscor": [5, 32, 35], "five": [5, 32, 35], "upgrade_genie_step4_141020_a_000000": [5, 32, 35], "upgrade_genie_step4_141020_a_000001": [5, 32, 35], "upgrade_genie_step4_141020_a_000008": [5, 32, 35], "upgrade_genie_step4_141020_a_000009": [5, 32, 35], "would": [5, 32, 35, 98], "upgrade_genie_step4_141020_a_00000x": [5, 32, 35], "suffix": [5, 32, 35], "upgrade_genie_step4_141020_a_000010": [5, 32, 35], "separ": [5, 28, 32, 35, 78, 100], "upgrade_genie_step4_141020_a_00001x": [5, 32, 35], "int": [5, 6, 8, 10, 12, 13, 19, 22, 32, 33, 35, 40, 48, 49, 55, 56, 57, 58, 59, 62, 64, 66, 67, 68, 70, 71, 72, 73, 75, 78, 80, 81, 82, 84, 88, 91, 95], "properti": [5, 8, 15, 20, 30, 49, 59, 66, 68, 72, 79, 87, 95], "file_suffix": [5, 32, 35], "execut": [5, 36], "method": [5, 8, 10, 12, 15, 27, 28, 29, 30, 32, 35, 44, 48, 49, 71, 80, 82], "set": [5, 17, 70, 71, 72, 98], "inherit": [5, 15, 30, 51, 66, 80, 95], "path": [5, 8, 10, 12, 13, 36, 39, 44, 63, 67, 75, 76, 84, 86, 87, 88, 93, 100], "correspond": [5, 8, 10, 12, 13, 28, 30, 35, 39, 56, 63, 82, 93, 100], "gcd": [5, 15, 29, 39, 44, 93], "save_data": [5, 32, 35], "save": [5, 15, 28, 32, 35, 36, 67, 75, 80, 81, 82, 86, 87, 88, 89, 100], "ordereddict": [5, 32, 35], "extract": [5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 35, 38, 39, 44, 70, 71, 72], "merge_fil": [5, 32, 35], "input_fil": [5, 32, 35], "merg": [5, 32, 35, 80, 100], "result": [5, 32, 35, 49, 78, 80, 81, 90, 100], "option": [5, 8, 10, 12, 13, 25, 32, 33, 35, 44, 48, 49, 56, 58, 63, 64, 67, 70, 71, 72, 75, 76, 80, 82, 83, 84, 86, 88, 93, 100], "default": [5, 8, 10, 12, 13, 17, 25, 28, 32, 33, 35, 36, 38, 44, 48, 49, 55, 56, 57, 58, 62, 63, 64, 66, 67, 70, 71, 72, 75, 76, 78, 79, 80, 82, 84, 86, 88, 93], "current": [5, 32, 35, 40, 78, 98, 100], "rais": [5, 8, 17, 32, 67, 86, 91], "notimplementederror": [5, 32], "If": [5, 8, 17, 32, 33, 35, 67, 70, 71, 72, 75, 78, 82, 98, 100], "been": [5, 32, 44, 80, 98], "backend": [5, 9, 11, 32, 35], "question": 5, "get_map_funct": 5, "nb_file": 5, "map": [5, 8, 10, 12, 13, 16, 17, 35, 36, 44, 52, 53, 86, 88, 89, 91], "pure": [5, 14, 15, 17, 30], "multiprocess": [5, 100], "tupl": [5, 8, 10, 12, 29, 30, 48, 56, 58, 70, 71, 72, 73, 75, 76, 81, 84], "remov": [6, 81, 84], "less": [6, 81], "two": [6, 56, 75, 78, 80, 81], "dom": [6, 8, 10, 12, 13, 46, 49, 81], "hit": [6, 81], "should": [6, 8, 10, 12, 13, 15, 28, 40, 48, 49, 80, 81, 86, 88, 89, 91, 98, 100], "occur": [6, 81], "product": [6, 81], "selection_nam": 6, "check": [6, 29, 30, 35, 36, 84, 93, 94, 98, 100], "whether": [6, 29, 30, 35, 36, 56, 67, 80, 90, 93, 94], "shuffl": [6, 39, 81], "select": [6, 8, 10, 12, 13, 22, 40, 81, 82, 88, 98], "bool": [6, 29, 30, 35, 36, 40, 44, 46, 56, 67, 68, 75, 78, 80, 81, 82, 84, 90, 93, 94, 95], "batch_siz": [6, 33, 73, 81], "num_work": [6, 81], "persistent_work": [6, 81], "prefetch_factor": 6, "kwarg": [6, 48, 62, 67, 70, 72, 80, 82, 86, 95], "t_co": 6, "classmethod": [6, 8, 67, 80, 86, 87], "from_dataset_config": 6, "datasetconfig": [6, 8, 40, 85, 88], "dict": [6, 8, 13, 17, 28, 30, 33, 35, 51, 52, 53, 63, 67, 68, 75, 76, 78, 80, 81, 84, 86, 88, 89, 90, 91], "parquet_dataset": [7, 9], "sqlite_dataset": [7, 11], "sqlite_dataset_perturb": [7, 11], "columnmissingexcept": [7, 8], "load_modul": [7, 8, 67], "parse_graph_definit": [7, 8], "ensembledataset": [7, 8, 88], "except": 8, "indic": [8, 40, 49, 78, 84, 98], "miss": 8, "column": [8, 10, 12, 13, 36, 44, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 82], "class_nam": [8, 62, 67, 89, 95], "cfg": 8, "graphdefinit": [8, 10, 12, 44, 60, 61, 63, 64, 65, 68, 81, 98], "graph_definit": [8, 10, 12, 44, 45, 60, 68, 81, 88], "pulsemap": [8, 10, 12, 13, 16, 35, 44, 81, 88], "node_truth": [8, 10, 12, 13, 81, 88], "truth_tabl": [8, 10, 12, 13, 75, 81, 82, 88], "node_truth_t": [8, 10, 12, 13, 81, 88], "string_select": [8, 10, 12, 13, 81, 88], "dtype": [8, 10, 12, 13, 63, 64, 96], "loss_weight_t": [8, 10, 12, 13, 81, 88], "loss_weight_column": [8, 10, 12, 13, 63, 81, 88], "loss_weight_default_valu": [8, 10, 12, 13, 63, 88], "seed": [8, 10, 12, 13, 40, 81, 88], "puls": [8, 10, 12, 13, 16, 17, 29, 30, 35, 36, 44, 46, 49, 66, 73], "seri": [8, 10, 12, 13, 16, 17, 29, 30, 36, 44], "node": [8, 10, 12, 13, 45, 46, 49, 55, 56, 58, 60, 61, 62, 63, 64, 70, 71, 72, 73], "multipl": [8, 10, 12, 13, 15, 78, 88, 95], "store": [8, 10, 12, 13, 15, 33, 36, 75, 79], "ad": [8, 10, 12, 13, 16, 56, 63, 75], "attribut": [8, 10, 12, 13, 46, 70, 71, 72], "event_no": [8, 10, 12, 13, 36, 40, 82, 88], "uniqu": [8, 10, 12, 13, 36, 38, 88], "indici": [8, 10, 12, 13, 29, 40, 80], "tabl": [8, 10, 12, 13, 15, 33, 35, 36, 63, 75, 82], "inform": [8, 10, 12, 13, 15, 17, 25, 76], "subset": [8, 10, 12, 13, 48, 56, 58], "given": [8, 10, 12, 13, 35, 49, 62, 70, 71, 72, 82, 84], "queri": [8, 10, 12, 36, 40], "pass": [8, 10, 12, 48, 55, 56, 57, 58, 59, 63, 67, 68, 70, 71, 72, 80, 82, 98], "float32": [8, 10, 12, 13, 63, 64], "tensor": [8, 10, 12, 13, 46, 48, 49, 51, 55, 56, 57, 58, 59, 66, 67, 68, 70, 71, 72, 73, 80, 96], "per": [8, 10, 12, 13, 17, 36, 49, 70, 71, 72, 80, 82], "loss": [8, 10, 12, 13, 63, 68, 70, 71, 72, 78, 80, 84], "weight": [8, 10, 12, 13, 44, 63, 70, 71, 72, 75, 80, 82, 89, 100], "also": [8, 10, 12, 13, 40, 88], "assign": [8, 10, 12, 13, 38, 46, 49, 98], "float": [8, 10, 12, 13, 44, 46, 55, 62, 63, 67, 75, 76, 78, 80, 81, 88], "note": [8, 10, 12, 13, 76, 89], "valu": [8, 10, 12, 13, 25, 28, 35, 36, 49, 63, 76, 79, 80, 84, 86], "specifi": [8, 10, 12, 13, 40, 46, 70, 71, 72, 76, 78, 100], "case": [8, 10, 12, 13, 17, 44, 49, 70, 71, 72, 100], "That": [8, 10, 12, 13, 56, 71, 79], "ignor": [8, 10, 12, 13, 30], "resolv": [8, 10, 12, 40], "10000": [8, 10, 12, 40], "20": [8, 10, 12, 40, 95], "defin": [8, 10, 12, 40, 44, 49, 60, 61, 62, 63, 65, 86, 88, 89, 91], "represent": [8, 10, 12, 30, 49, 64], "from_config": [8, 67, 87, 88, 89], "concaten": [8, 28, 56], "query_t": [8, 10, 12], "sequential_index": [8, 10, 12], "some": [8, 10, 12, 63], "out": [8, 56, 68, 69, 80, 95, 98, 100], "sequenti": 8, "len": 8, "self": [8, 63, 75, 86, 91], "_may_": 8, "_indic": 8, "entir": [8, 67], "impos": 8, "befor": [8, 56, 70, 71, 72, 78], "scalar": [8, 73, 80], "length": [8, 30, 78], "element": [8, 28, 30, 68, 73, 90], "present": [8, 84, 93, 94], "add_label": 8, "fn": [8, 30, 86, 90], "kei": [8, 17, 28, 29, 30, 35, 36, 46, 49, 79, 88, 89], "add": [8, 56, 84, 98, 100], "custom": [8, 63, 78], "concatdataset": 8, "singl": [8, 15, 49, 56, 79, 88, 89], "collect": [8, 14, 15, 27, 80, 96], "iter": 8, "parquetdataset": [9, 10], "pytorch": [10, 12, 13, 78, 100], "sqlitedataset": [11, 12, 13], "sqlitedatasetperturb": [11, 13], "databas": [12, 13, 33, 35, 36, 38, 75, 82, 100], "perturb": 13, "perturbation_dict": 13, "step": [13, 68, 78], "where": [13, 63, 64, 66, 79], "randomli": [13, 40, 89], "nois": [13, 16, 29, 44], "intend": [13, 100], "test": [13, 40, 70, 71, 72, 81, 88, 94, 98], "stabil": 13, "small": [13, 80], "chang": [13, 75, 80, 98], "dictionari": [13, 28, 29, 30, 33, 35, 63, 75, 76, 86, 88, 89, 91], "deviat": 13, "i3fram": [14, 15, 17, 29, 30, 44], "frame": [14, 15, 17, 27, 30, 35, 44], "i3extractorcollect": [14, 15], "i3featureextractoricecube86": [14, 16], "i3featureextractoricecubedeepcor": [14, 16], "i3featureextractoricecubeupgrad": [14, 16], "i3pulsenoisetruthflagicecubeupgrad": [14, 16], "i3galacticplanehybridrecoextractor": [14, 18], "i3ntmuonlabelextractor": [14, 19], "i3splinempeicextractor": [14, 24], "__call__": 15, "icetrai": [15, 29, 30, 44, 94], "keep": 15, "proven": 15, "set_fil": 15, "refer": [15, 88], "being": [15, 44, 70, 71, 72], "get": [15, 29, 78, 81, 100], "treat": 15, "86": [16, 52], "flag": [16, 44], "exclude_kei": 17, "dynam": [17, 48, 56, 57, 58], "pars": [17, 76, 83, 84, 85, 86, 91], "call": [17, 30, 35, 49, 75, 82, 95], "tri": [17, 30], "automat": [17, 80, 98], "cast": [17, 30], "done": [17, 49, 95, 98], "recurs": [17, 30, 90, 93], "each": [17, 28, 30, 36, 38, 39, 46, 49, 52, 53, 56, 58, 62, 63, 64, 66, 67, 70, 71, 72, 73, 75, 76, 78, 93], "look": [17, 100], "member": [17, 30, 88, 89, 95], "variabl": [17, 30, 56, 73, 82, 95], "signatur": [17, 30], "similar": [17, 30, 100], "handl": [17, 80, 84, 95], "hand": 17, "mc": [17, 35, 36], "tree": [17, 35], "trigger": 17, "exclud": [17, 38, 100], "valueerror": [17, 67], "hybrid": 18, "galatict": 18, "plane": [18, 80], "tum": [19, 26], "dnn": [19, 26], "padding_valu": [19, 22], "northeren": 19, "i3particl": 20, "other": [20, 36, 62, 80, 98], "algorithm": 20, "comparison": [20, 80], "quantiti": [21, 70, 71, 72, 73], "queso": 22, "retro": [23, 33], "splinemp": 24, "border": 25, "mctree": [25, 29], "ndarrai": [25, 63, 82], "arrai": [25, 28], "boundari": 25, "volum": 25, "coordin": [25, 73], "particl": [25, 36, 79], "start": [25, 98, 100], "stop": [25, 84], "within": [25, 46, 48, 49, 56, 62], "hard": 25, "i3mctre": 25, "flatten_nested_dictionari": [27, 28], "serialis": [27, 28], "transpose_list_of_dict": [27, 28], "frame_is_montecarlo": [27, 29], "frame_is_nois": [27, 29], "get_om_keys_and_pulseseri": [27, 29], "is_boost_enum": [27, 30], "is_boost_class": [27, 30], "is_icecube_class": [27, 30], "is_typ": [27, 30], "is_method": [27, 30], "break_cyclic_recurs": [27, 30], "get_member_vari": [27, 30], "cast_object_to_pure_python": [27, 30], "cast_pulse_series_to_pure_python": [27, 30], "manipul": [28, 60, 61, 65], "obj": [28, 30, 90], "parent_kei": 28, "flatten": 28, "nest": 28, "non": [28, 30, 35, 36, 80], "exampl": [28, 40, 46, 49, 80, 88, 89, 100], "d": [28, 63, 66, 98], "b": [28, 46, 49], "c": [28, 49, 80, 100], "2": [28, 49, 56, 58, 62, 64, 71, 73, 75, 76, 80, 88, 100], "a__b": 28, "applic": 28, "combin": [28, 88], "parent": 28, "__": [28, 30], "nester": 28, "json": [28, 88], "therefor": 28, "we": [28, 30, 40, 98, 100], "outer": 28, "abl": [28, 100], "de": 28, "transpos": 28, "mont": 29, "carlo": 29, "simul": [29, 44], "pulseseri": 29, "calibr": [29, 30], "gcd_dict": [29, 30], "p": [29, 35, 80], "om": [29, 30], "dataclass": 29, "i3calibr": 29, "indicesfor": 29, "boost": 30, "enum": 30, "ensur": [30, 39, 80, 95, 98, 100], "isn": 30, "return_discard": 30, "valid": [30, 40, 68, 70, 71, 72, 80, 84, 86, 91], "mangl": 30, "take": [30, 35, 49, 98], "mainli": 30, "cannot": [30, 86, 91], "trivial": [30, 72], "doe": [30, 89], "try": 30, "equival": 30, "its": 30, "like": [30, 49, 73, 80, 96, 98], "otherwis": [30, 80], "itself": [30, 70, 71, 72], "deem": 30, "wai": [30, 40, 98, 100], "optic": 30, "found": [30, 80], "parquetdataconvert": [31, 32], "module_dict": 33, "devic": 33, "retro_table_nam": 33, "n_worker": [33, 75], "pipeline_nam": 33, "creat": [33, 35, 36, 63, 86, 87, 91, 98, 100], "initialis": [33, 89], "gnn_module_for_energy_regress": 33, "modulelist": 33, "comput": [33, 68, 70, 71, 72, 73, 80], "directori": [33, 38, 75, 93], "100": [33, 100], "size": [33, 48, 49, 56, 57, 58, 84], "alreadi": [33, 36, 100], "error": [33, 80, 95, 98], "prompt": 33, "avoid": [33, 95, 98], "overwrit": [33, 78], "sqlitedataconvert": [34, 35, 100], "construct_datafram": [34, 35], "is_pulse_map": [34, 35], "is_mc_tre": [34, 35], "database_exist": [34, 36], "database_table_exist": [34, 36], "run_sql_cod": [34, 36], "save_to_sql": [34, 36], "attach_index": [34, 36], "create_t": [34, 36], "create_table_and_save_to_sql": [34, 36], "db": [35, 81], "max_table_s": 35, "maximum": [35, 49, 70, 71, 72, 84], "row": [35, 36], "exce": 35, "limit": [35, 80], "any_pulsemap_is_non_empti": 35, "data_dict": 35, "empti": [35, 44], "retriev": 35, "splitinicepuls": 35, "least": [35, 98, 100], "true": [35, 36, 44, 75, 78, 80, 82, 88, 89, 91], "becaus": [35, 39], "instead": [35, 80, 86, 91], "alwai": 35, "panda": [35, 40, 82], "datafram": [35, 36, 40, 67, 68, 75, 81, 82], "table_nam": [35, 36], "database_path": [36, 75, 82], "df": 36, "must": [36, 46, 78, 82, 98], "attach": 36, "default_typ": 36, "null": 36, "integer_primary_kei": 36, "NOT": [36, 80], "integ": [36, 56, 57, 80], "primari": 36, "Such": 36, "appropri": [36, 70, 71, 72], "expect": [36, 40, 44, 66], "doesn": 36, "parquettosqliteconvert": [37, 38], "pairwise_shuffl": [37, 39], "stringselectionresolv": [37, 40], "parquet_path": 38, "mc_truth_tabl": 38, "excluded_field": 38, "id": 38, "everi": [38, 100], "field": [38, 76, 79, 86, 88, 89, 91], "One": [38, 76], "choos": 38, "argument": [38, 82, 84, 86, 88, 89, 91], "exclude_field": 38, "database_nam": 38, "convers": [38, 100], "rng": 39, "relat": [39, 93], "i3_list": [39, 93], "gcd_list": [39, 93], "correpond": 39, "handi": 39, "even": 39, "files_list": 39, "gcd_shuffl": 39, "i3_shuffl": 39, "use_cach": 40, "flexibl": 40, "below": [40, 76, 82, 98, 100], "show": [40, 78], "involv": 40, "cover": 40, "yml": [40, 84, 88, 89], "50000": [40, 88], "ab": [40, 80, 88], "12": [40, 88], "14": [40, 88], "16": [40, 88], "13": [40, 100], "compat": 40, "syntax": [40, 80], "mai": [40, 66, 100], "fix": 40, "graphnet_modul": [41, 42], "graphneti3modul": [42, 44], "i3inferencemodul": [42, 44], "i3pulsecleanermodul": [42, 44], "pulsemap_extractor": 44, "produc": [44, 79, 82], "write": [44, 100], "constructor": 44, "knngraph": [44, 60, 64], "associ": [44, 63, 71, 80], "model_config": [44, 83, 85, 86, 88, 91], "state_dict": [44, 67], "model_nam": [44, 75], "prediction_column": [44, 67, 68, 81], "pulsmap": 44, "modelconfig": [44, 67, 85, 88, 89], "summar": 44, "Will": [44, 62], "help": [44, 84, 98], "entri": [44, 56, 76, 84], "dynedg": [44, 45, 54, 57, 58], "energy_reco": 44, "discard_empty_ev": 44, "clean": [44, 98, 100], "assum": [44, 51, 72, 73], "7": [44, 49, 75], "consid": [44, 100], "posit": [44, 49, 71], "signal": 44, "els": 44, "fals": [44, 56, 67, 75, 78, 80, 82, 88], "elimin": 44, "speed": 44, "especi": 44, "sinc": [44, 80], "further": 44, "calcul": [44, 62, 64, 68, 73, 79, 80], "convnet": [45, 54], "dynedge_jinst": [45, 54], "dynedge_kaggle_tito": [45, 54], "edg": [45, 48, 49, 56, 57, 58, 60, 63, 64, 65, 66, 73], "unbatch_edge_index": [45, 46], "attributecoarsen": [45, 46], "domcoarsen": [45, 46], "customdomcoarsen": [45, 46], "domandtimewindowcoarsen": [45, 46], "standardmodel": [45, 68], "calculate_xyzt_homophili": [45, 73], "calculate_distance_matrix": [45, 73], "knn_graph_batch": [45, 73], "oper": [46, 48, 54, 56], "cluster": [46, 48, 49, 56, 58], "local": [46, 84], "edge_index": [46, 48, 73], "vector": [46, 49, 80], "longtensor": [46, 49, 73], "mathbf": [46, 49], "ldot": [46, 49], "n": [46, 49, 80], "reduc": 46, "transfer_attribut": 46, "reduce_opt": 46, "avg": 46, "avg_pool": 46, "avg_pool_x": 46, "max": [46, 48, 56, 58, 80, 84], "max_pool": [46, 49], "max_pool_x": [46, 49], "min": [46, 49, 56, 58], "min_pool": [46, 47, 49], "min_pool_x": [46, 47, 49], "sum": [46, 49, 56, 58, 68], "sum_pool": [46, 47, 49], "sum_pool_x": [46, 47, 49], "forward": [46, 48, 51, 55, 56, 57, 58, 59, 62, 63, 66, 67, 68, 72, 80], "simplecoarsen": 46, "addit": [46, 48, 67, 68, 80, 82], "time_window": 46, "time_kei": 46, "window": 46, "dynedgeconv": [47, 48, 56], "edgeconvtito": [47, 48], "dyntran": [47, 48, 58], "sum_pool_and_distribut": [47, 49], "group_bi": [47, 49], "group_pulses_to_dom": [47, 49], "group_pulses_to_pmt": [47, 49], "std_pool_x": [47, 49], "std_pool": [47, 49], "aggr": 48, "nb_neighbor": 48, "features_subset": [48, 56, 58], "edgeconv": 48, "lightningmodul": [48, 67, 78, 95], "convolut": [48, 55, 56, 57, 58], "mlp": [48, 56], "aggreg": [48, 49], "8": [48, 49, 56, 64, 80, 98, 100], "neighbour": [48, 56, 58, 62, 64, 73], "after": [48, 56, 78, 84], "sequenc": 48, "slice": [48, 56, 58], "sparsetensor": 48, "messagepass": 48, "tito": [48, 58], "solut": [48, 58, 98], "deep": [48, 58], "competit": [48, 52, 58], "reset_paramet": 48, "reset": 48, "learnabl": [48, 54, 55, 56, 57, 58, 59], "messag": [48, 78, 95], "x_i": 48, "x_j": 48, "layer_s": 48, "n_head": 48, "dyntrans1": 48, "head": 48, "multiheadattent": 48, "just": [49, 100], "negat": 49, "cluster_index": 49, "distribut": [49, 56, 71, 80, 82], "ident": [49, 72], "pmt": 49, "f1": 49, "f2": 49, "6": [49, 76], "groupbi": 49, "3": [49, 55, 58, 71, 73, 75, 76, 80, 98, 100], "matrix": [49, 62, 73, 80], "mathbb": 49, "r": [49, 62, 100], "n_1": 49, "n_b": 49, "obtain": [49, 80], "wise": 49, "dens": 49, "fc": 49, "known": 49, "std": 49, "repres": [49, 63, 64, 66, 86, 88, 89], "averag": [49, 80], "torch_geometr": 49, "version": [49, 70, 71, 72, 78, 98, 100], "standardis": 50, "icecubekaggl": [50, 52], "icecubedeepcor": [50, 52], "icecubeupgrad": [50, 52], "ins": 51, "feature_map": [51, 52, 53], "node_featur": [51, 63], "node_feature_nam": [51, 63, 64, 66], "adjac": 51, "dimens": [52, 53, 55, 56, 58, 80], "prototyp": 53, "dynedgejinst": [54, 57], "dynedgetito": [54, 58], "author": [55, 57, 80], "martin": 55, "minh": 55, "nb_input": [55, 56, 57, 58, 59, 70, 71, 72], "nb_output": [55, 57, 59, 66, 70, 72], "nb_intermedi": 55, "dropout_ratio": 55, "128": [55, 56, 84], "fraction": 55, "drop": 55, "nb_neighbour": 56, "dynedge_layer_s": 56, "post_processing_layer_s": 56, "readout_layer_s": 56, "global_pooling_schem": [56, 58], "add_global_variables_after_pool": 56, "k": [56, 58, 62, 64, 73, 80], "nearest": [56, 58, 62, 64, 73], "latent": [56, 58, 70], "metric": [56, 58, 78], "dimenion": [56, 58], "multi": 56, "perceptron": 56, "256": 56, "336": 56, "hidden": [56, 57, 70, 72], "skip": 56, "post": 56, "_and_": 56, "As": 56, "last": [56, 70, 72, 78], "scheme": [56, 58], "altern": [56, 80, 98], "exact": [57, 80], "2209": 57, "03042": 57, "oerso": 57, "layer_size_scal": 57, "4": [57, 58, 71, 76], "scale": [57, 63, 70, 71, 72, 80], "ic": 58, "univers": 58, "south": 58, "pole": 58, "dyntrans_layer_s": 58, "core": 59, "edgedefinit": [60, 61, 62, 63, 65], "how": [60, 61, 65], "drawn": [60, 61, 64, 65], "between": [60, 61, 62, 65, 68, 73, 78, 80], "knnedg": [61, 62], "radialedg": [61, 62], "euclideanedg": [61, 62], "log_fold": [62, 67, 95], "_construct_edg": 62, "nb_nearest_neighbour": [62, 64], "definit": [62, 63, 64, 66, 67, 98], "space": [62, 82], "distanc": [62, 64, 73], "radiu": 62, "sphere": 62, "chosen": [62, 95], "centr": 62, "radial": 62, "center": 62, "sigma": 62, "euclidean": [62, 98], "see": [62, 63, 78, 98, 100], "http": [62, 63, 80, 98], "arxiv": [62, 80], "org": [62, 80, 100], "pdf": 62, "1809": 62, "06166": 62, "hold": 63, "alter": 63, "dure": [63, 70, 71, 72, 78], "node_definit": [63, 64], "edge_definit": 63, "geometri": 63, "nodedefinit": [63, 64, 65, 66], "truth_dict": 63, "custom_label_funct": 63, "loss_weight": [63, 70, 71, 72], "data_path": 63, "shape": [63, 66, 73, 80], "num_nod": 63, "github": [63, 80, 100], "com": [63, 80, 100], "team": [63, 98], "blob": [63, 80], "getting_start": 63, "md": 63, "your": [64, 98, 100], "nodesaspuls": [65, 66], "num_puls": 66, "overridden": 66, "set_number_of_input": 66, "measur": [66, 73], "cherenkov": 66, "radiat": 66, "train_dataload": 67, "val_dataload": 67, "max_epoch": 67, "gpu": [67, 68, 84, 100], "ckpt_path": 67, "log_every_n_step": 67, "gradient_clip_v": 67, "distribution_strategi": [67, 68], "trainer_kwarg": 67, "pytorch_lightn": [67, 95], "trainer": [67, 78, 81], "predict_as_datafram": [67, 68], "additional_attribut": [67, 68, 81], "save_state_dict": 67, "load_state_dict": 67, "karg": 67, "trust": 67, "enough": 67, "eval": [67, 100], "lambda": 67, "consequ": 67, "optimizer_class": 68, "optim": [68, 78], "adam": 68, "optimizer_kwarg": 68, "scheduler_class": 68, "scheduler_kwarg": 68, "scheduler_config": 68, "target_label": [68, 70, 71, 72], "target": [68, 70, 71, 72, 80, 91], "prediction_label": [68, 70, 71, 72], "configure_optim": 68, "shared_step": 68, "batch_idx": 68, "share": 68, "training_step": 68, "train_batch": 68, "validation_step": 68, "val_batch": 68, "compute_loss": [68, 70, 71, 72], "pred": [68, 72], "verbos": [68, 78], "activ": [68, 72, 98, 100], "mode": [68, 72], "deactiv": [68, 72], "multiclassclassificationtask": [69, 70], "binaryclassificationtask": [69, 70], "binaryclassificationtasklogit": [69, 70], "azimuthreconstructionwithkappa": [69, 71], "azimuthreconstruct": [69, 71], "directionreconstructionwithkappa": [69, 71], "zenithreconstruct": [69, 71], "zenithreconstructionwithkappa": [69, 71], "energyreconstruct": [69, 71], "energyreconstructionwithpow": [69, 71], "energyreconstructionwithuncertainti": [69, 71], "vertexreconstruct": [69, 71], "positionreconstruct": [69, 71], "timereconstruct": [69, 71], "inelasticityreconstruct": [69, 71], "identitytask": [69, 70, 72], "arg": [70, 72, 80, 84, 86, 91, 95], "classifi": 70, "untransform": 70, "logit": [70, 80], "affin": [70, 71, 72], "hidden_s": [70, 71, 72], "transform_prediction_and_target": [70, 71, 72], "transform_target": [70, 71, 72], "transform_infer": [70, 71, 72], "transform_support": [70, 71, 72], "binari": [70, 80], "feed": [70, 71, 72], "lossfunct": [70, 71, 72, 77, 80], "auto": [70, 71, 72], "matic": [70, 71, 72], "_pred": [70, 71, 72], "numer": [70, 71, 72], "stabl": [70, 71, 72], "log10": [70, 71, 72, 82], "rather": [70, 71, 72, 95], "conjunct": [70, 71, 72], "invers": [70, 71, 72], "recov": [70, 71, 72], "minimum": [70, 71, 72], "restrict": [70, 71, 72, 80], "invert": [70, 71, 72], "1e6": [70, 71, 72], "default_target_label": [70, 71, 72], "default_prediction_label": [70, 71, 72], "target_pr": 70, "angl": [71, 79], "kappa": [71, 80], "var": 71, "azimuth_pr": 71, "azimuth_kappa": 71, "3d": [71, 80], "vmf": 71, "dir_x_pr": 71, "dir_y_pr": 71, "dir_z_pr": 71, "direction_kappa": 71, "zenith_pr": 71, "zenith_kappa": 71, "energy_pr": 71, "uncertainti": 71, "energy_sigma": 71, "vertex": 71, "position_x_pr": 71, "position_y_pr": 71, "position_z_pr": 71, "interaction_time_pr": 71, "interact": 71, "hadron": 71, "inelasticity_pr": 71, "wrt": 72, "train_ev": 72, "xyzt": 73, "homophili": 73, "notic": [73, 80], "xyz_coord": 73, "pairwis": 73, "nb_dom": 73, "updat": [73, 75, 78], "config_updat": [74, 75], "weightfitt": [74, 75, 77, 82], "contourfitt": [74, 75], "read_entri": [74, 76], "plot_2d_contour": [74, 76], "plot_1d_contour": [74, 76], "contour": [75, 76], "config_path": 75, "new_config_path": 75, "dummy_sect": 75, "temp": 75, "dummi": 75, "section": 75, "header": 75, "configupdat": 75, "programat": 75, "statistical_fit": 75, "fit_weight": [75, 82], "config_outdir": 75, "weight_nam": [75, 82], "pisa_config_dict": 75, "add_to_databas": [75, 82], "flux": 75, "_database_path": 75, "statist": 75, "effect": [75, 78, 98], "account": 75, "systemat": 75, "hypersurfac": 75, "assumpt": 75, "regard": 75, "pipeline_path": 75, "post_fix": 75, "include_retro": 75, "fit_1d_contour": 75, "run_nam": 75, "config_dict": 75, "grid_siz": 75, "theta23_minmax": 75, "36": 75, "54": 75, "dm31_minmax": 75, "1d": [75, 76], "fit_2d_contour": 75, "2d": [75, 76, 80], "content": 76, "contour_data": 76, "xlim": 76, "ylim": 76, "0023799999999999997": 76, "0025499999999999997": 76, "chi2_critical_valu": 76, "width": 76, "height": 76, "path_to_pisa_fit_result": 76, "name_of_my_model_in_fit": 76, "legend": 76, "color": 76, "linestyl": 76, "style": [76, 98], "line": [76, 78, 84], "upper": 76, "axi": 76, "605": 76, "critic": [76, 95], "chi2": 76, "90": 76, "cl": 76, "right": [76, 80], "176": 76, "inch": 76, "388": 76, "706": 76, "abov": [76, 80, 82, 100], "352": 76, "piecewiselinearlr": [77, 78], "progressbar": [77, 78], "mseloss": [77, 80], "rmseloss": [77, 80], "logcoshloss": [77, 80], "crossentropyloss": [77, 80], "binarycrossentropyloss": [77, 80], "logcmk": [77, 80], "vonmisesfisherloss": [77, 80], "vonmisesfisher2dloss": [77, 80], "euclideandistanceloss": [77, 80], "vonmisesfisher3dloss": [77, 80], "make_dataload": [77, 81], "make_train_validation_dataload": [77, 81], "get_predict": [77, 81], "save_result": [77, 81], "uniform": [77, 82], "bjoernlow": [77, 82], "mileston": 78, "factor": 78, "last_epoch": 78, "_lrschedul": 78, "interpol": 78, "linearli": 78, "denot": 78, "multipli": 78, "closest": 78, "vice": 78, "versa": 78, "wrap": [78, 88, 89], "epoch": [78, 84], "print": [78, 95], "stdout": 78, "get_lr": 78, "refresh_r": 78, "process_posit": 78, "tqdmprogressbar": 78, "progress": 78, "bar": 78, "customis": 78, "lightn": 78, "init_validation_tqdm": 78, "overrid": 78, "init_predict_tqdm": 78, "init_test_tqdm": 78, "init_train_tqdm": 78, "get_metr": 78, "on_train_epoch_start": 78, "previou": 78, "behaviour": 78, "on_train_epoch_end": 78, "don": [78, 100], "duplciat": 78, "runtim": [79, 100], "azimuth_kei": 79, "zenith_kei": 79, "access": [79, 100], "azimiuth": 79, "return_el": 80, "elementwis": 80, "term": 80, "squar": 80, "root": [80, 100], "cosh": 80, "act": 80, "cross": 80, "entropi": 80, "num_class": 80, "softmax": 80, "ed": 80, "probabl": 80, "mit": 80, "licens": 80, "copyright": 80, "2019": 80, "ryabinin": 80, "permiss": 80, "herebi": 80, "person": 80, "copi": 80, "document": 80, "deal": 80, "modifi": 80, "publish": 80, "sublicens": 80, "sell": 80, "permit": 80, "whom": 80, "furnish": 80, "so": [80, 100], "subject": 80, "condit": 80, "shall": 80, "substanti": 80, "portion": 80, "THE": 80, "AS": 80, "warranti": 80, "OF": 80, "kind": 80, "OR": 80, "impli": 80, "BUT": 80, "TO": 80, "merchant": 80, "FOR": 80, "particular": [80, 98], "AND": 80, "noninfring": 80, "IN": 80, "NO": 80, "holder": 80, "BE": 80, "liabl": 80, "claim": 80, "damag": 80, "liabil": 80, "action": 80, "contract": 80, "tort": 80, "aris": 80, "WITH": 80, "_____________________": 80, "mryab": 80, "vmf_loss": 80, "master": 80, "py": [80, 100], "bessel": 80, "exponenti": 80, "ditto": 80, "iv": 80, "1812": 80, "04616": 80, "spite": 80, "suggest": 80, "sec": 80, "paper": 80, "m": 80, "correct": 80, "static": [80, 98], "ctx": 80, "backward": 80, "grad_output": 80, "von": 80, "mise": 80, "fisher": 80, "log_cmk_exact": 80, "c_": 80, "exactli": [80, 95], "log_cmk_approx": 80, "approx": 80, "minu": 80, "sign": 80, "log_cmk": 80, "kappa_switch": 80, "diverg": 80, "700": 80, "float64": 80, "precis": 80, "unaccur": 80, "switch": 80, "three": 80, "database_indic": 81, "test_siz": 81, "node_level": 81, "tag": [81, 98, 100], "archiv": 81, "public": 82, "uniformweightfitt": 82, "bin": 82, "privat": 82, "_fit_weight": 82, "sql": 82, "desir": [82, 93], "np": 82, "happen": 82, "x_low": 82, "wherea": 82, "curv": 82, "base_config": [83, 85], "dataset_config": [83, 85], "training_config": [83, 85], "argumentpars": [83, 84], "is_gcd_fil": [83, 93], "is_i3_fil": [83, 93], "has_extens": [83, 93], "find_i3_fil": [83, 93], "has_icecube_packag": [83, 94], "has_torch_packag": [83, 94], "has_pisa_packag": [83, 94], "requires_icecub": [83, 94], "repeatfilt": [83, 95], "eps_lik": [83, 96], "consist": [84, 95, 98], "cli": 84, "pop_default": 84, "usag": 84, "descript": 84, "command": [84, 100], "standard_argu": 84, "home": [84, 100], "runner": 84, "lib": [84, 100], "python3": 84, "training_example_data_sqlit": 84, "earli": 84, "patienc": 84, "narg": 84, "50": 84, "example_energy_reconstruction_model": 84, "num": 84, "fetch": 84, "with_standard_argu": 84, "overwritten": [84, 86], "baseconfig": [85, 86, 87, 88, 89, 91], "get_all_argument_valu": [85, 86], "save_dataset_config": [85, 88], "save_model_config": [85, 89], "traverse_and_appli": [85, 90], "list_all_submodul": [85, 90], "get_all_grapnet_class": [85, 90], "is_graphnet_modul": [85, 90], "is_graphnet_class": [85, 90], "get_graphnet_class": [85, 90], "trainingconfig": [85, 91], "basemodel": [86, 88, 89], "keyword": [86, 91], "validationerror": [86, 91], "pydantic_cor": [86, 91], "__init__": [86, 88, 89, 91, 100], "__pydantic_self__": [86, 91], "dump": [86, 88, 89], "yaml": [86, 87], "as_dict": [86, 88, 89], "classvar": [86, 88, 89, 91], "configdict": [86, 88, 89, 91], "conform": [86, 88, 89, 91], "pydant": [86, 88, 89, 91], "model_field": [86, 88, 89, 91], "fieldinfo": [86, 88, 89, 91], "metadata": [86, 88, 89, 91], "about": [86, 88, 89, 91], "__fields__": [86, 88, 89, 91], "v1": [86, 88, 89, 91, 100], "re": [87, 100], "save_config": 87, "dataconfig": 88, "transpar": [88, 89, 98], "reproduc": [88, 89], "In": [88, 89, 100], "session": [88, 89], "anoth": [88, 89], "you": [88, 89, 98, 100], "still": 88, "csv": 88, "train_select": 88, "test_select": 88, "unambigu": [88, 89], "annot": [88, 89, 91], "nonetyp": 88, "init_fn": [88, 89], "trainabl": 89, "hyperparamet": 89, "instanti": 89, "thu": 89, "fn_kwarg": 90, "structur": 90, "moduletyp": 90, "grapnet": 90, "lookup": 90, "early_stopping_pati": 91, "system": [93, 100], "filenam": 93, "dir": 93, "search": 93, "test_funct": 94, "filter": 95, "repeat": 95, "nb_repeats_allow": 95, "record": 95, "logrecord": 95, "clear": 95, "intuit": 95, "composit": 95, "loggeradapt": 95, "clash": 95, "setlevel": 95, "deleg": 95, "msg": 95, "warn": 95, "info": [95, 100], "debug": 95, "warning_onc": 95, "onc": 95, "handler": 95, "file_handl": 95, "filehandl": 95, "stream_handl": 95, "streamhandl": 95, "assort": 96, "ep": 96, "api": 97, "To": [98, 100], "sure": [98, 100], "smooth": 98, "guidelin": 98, "guid": 98, "encourag": 98, "contributor": 98, "discuss": 98, "bug": 98, "anyth": 98, "place": 98, "describ": 98, "yourself": 98, "ownership": 98, "prioriti": 98, "situat": 98, "lot": 98, "effort": 98, "go": 98, "turn": 98, "outsid": 98, "scope": 98, "better": 98, "fork": 98, "repo": 98, "dedic": 98, "branch": [98, 100], "repositori": 98, "own": [98, 100], "accept": 98, "autom": 98, "review": 98, "pep8": 98, "docstr": 98, "googl": 98, "hint": 98, "adher": 98, "pep": 98, "pylint": 98, "flake8": 98, "black": 98, "well": 98, "recommend": [98, 100], "mypi": 98, "pydocstyl": 98, "docformatt": 98, "commit": 98, "hook": 98, "instal": 98, "come": 98, "pip": [98, 100], "Then": 98, "everytim": 98, "pep257": 98, "concept": 98, "ljvmiranda921": 98, "io": 98, "notebook": 98, "2018": 98, "06": 98, "21": 98, "precommit": 98, "environ": 100, "virtual": 100, "anaconda": 100, "prove": 100, "instruct": 100, "setup": 100, "want": 100, "part": 100, "achiev": 100, "bash": 100, "shell": 100, "cvmf": 100, "opensciencegrid": 100, "py3": 100, "v4": 100, "sh": 100, "rhel_7_x86_64": 100, "metaproject": 100, "env": 100, "alia": 100, "script": 100, "With": 100, "now": 100, "light": 100, "extra": 100, "geometr": 100, "won": 100, "later": 100, "torch_cpu": 100, "txt": 100, "cpu": 100, "torch_gpu": 100, "prefer": 100, "unix": 100, "git": 100, "clone": 100, "usernam": 100, "cd": 100, "conda": 100, "gcc_linux": 100, "64": 100, "gxx_linux": 100, "libgcc": 100, "cudatoolkit": 100, "11": 100, "forg": 100, "torch_maco": 100, "On": 100, "maco": 100, "box": 100, "compil": 100, "gcc": 100, "date": 100, "possibli": 100, "cuda": 100, "toolkit": 100, "recent": 100, "omit": 100, "newer": 100, "export": 100, "ld_library_path": 100, "anaconda3": 100, "miniconda3": 100, "bashrc": 100, "librari": 100, "rm": 100, "asogaard": 100, "latest": 100, "dc423315742c": 100, "01_icetrai": 100, "01_convert_i3_fil": 100, "2023": 100, "01": 100, "24": 100, "41": 100, "27": 100, "graphnet_20230124": 100, "134127": 100, "46": 100, "convert_i3_fil": 100, "ic86": 100, "thread": 100, "00": 100, "79": 100, "42": 100, "26": 100, "413": 100, "88it": 100, "specialis": 100, "ones": 100, "push": 100, "vx": 100}, "objects": {"": [[1, 0, 0, "-", "graphnet"]], "graphnet": [[2, 0, 0, "-", "constants"], [3, 0, 0, "-", "data"], [41, 0, 0, "-", "deployment"], [45, 0, 0, "-", "models"], [74, 0, 0, "-", "pisa"], [77, 0, 0, "-", "training"], [83, 0, 0, "-", "utilities"]], "graphnet.data": [[4, 0, 0, "-", "constants"], [5, 0, 0, "-", "dataconverter"], [6, 0, 0, "-", "dataloader"], [7, 0, 0, "-", "dataset"], [14, 0, 0, "-", "extractors"], [31, 0, 0, "-", "parquet"], [33, 0, 0, "-", "pipeline"], [34, 0, 0, "-", "sqlite"], [37, 0, 0, "-", "utilities"]], "graphnet.data.constants": [[4, 1, 1, "", "FEATURES"], [4, 1, 1, "", "TRUTH"]], "graphnet.data.constants.FEATURES": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.constants.TRUTH": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.dataconverter": [[5, 1, 1, "", "DataConverter"], [5, 1, 1, "", "FileSet"], [5, 5, 1, "", "cache_output_files"], [5, 5, 1, "", "init_global_index"]], "graphnet.data.dataconverter.DataConverter": [[5, 3, 1, "", "execute"], [5, 4, 1, "", "file_suffix"], [5, 3, 1, "", "get_map_function"], [5, 3, 1, "", "merge_files"], [5, 3, 1, "", "save_data"]], "graphnet.data.dataconverter.FileSet": [[5, 2, 1, "", "gcd_file"], [5, 2, 1, "", "i3_file"]], "graphnet.data.dataloader": [[6, 1, 1, "", "DataLoader"], [6, 5, 1, "", "collate_fn"], [6, 5, 1, "", "do_shuffle"]], "graphnet.data.dataloader.DataLoader": [[6, 3, 1, "", "from_dataset_config"]], "graphnet.data.dataset": [[8, 0, 0, "-", "dataset"], [9, 0, 0, "-", "parquet"], [11, 0, 0, "-", "sqlite"]], "graphnet.data.dataset.dataset": [[8, 6, 1, "", "ColumnMissingException"], [8, 1, 1, "", "Dataset"], [8, 1, 1, "", "EnsembleDataset"], [8, 5, 1, "", "load_module"], [8, 5, 1, "", "parse_graph_definition"]], "graphnet.data.dataset.dataset.Dataset": [[8, 3, 1, "", "add_label"], [8, 3, 1, "", "concatenate"], [8, 3, 1, "", "from_config"], [8, 4, 1, "", "path"], [8, 3, 1, "", "query_table"], [8, 4, 1, "", "truth_table"]], "graphnet.data.dataset.parquet": [[10, 0, 0, "-", "parquet_dataset"]], "graphnet.data.dataset.parquet.parquet_dataset": [[10, 1, 1, "", "ParquetDataset"]], "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset": [[10, 3, 1, "", "query_table"]], "graphnet.data.dataset.sqlite": [[12, 0, 0, "-", "sqlite_dataset"], [13, 0, 0, "-", "sqlite_dataset_perturbed"]], "graphnet.data.dataset.sqlite.sqlite_dataset": [[12, 1, 1, "", "SQLiteDataset"]], "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset": [[12, 3, 1, "", "query_table"]], "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed": [[13, 1, 1, "", "SQLiteDatasetPerturbed"]], "graphnet.data.extractors": [[15, 0, 0, "-", "i3extractor"], [16, 0, 0, "-", "i3featureextractor"], [17, 0, 0, "-", "i3genericextractor"], [18, 0, 0, "-", "i3hybridrecoextractor"], [19, 0, 0, "-", "i3ntmuonlabelsextractor"], [20, 0, 0, "-", "i3particleextractor"], [21, 0, 0, "-", "i3pisaextractor"], [22, 0, 0, "-", "i3quesoextractor"], [23, 0, 0, "-", "i3retroextractor"], [24, 0, 0, "-", "i3splinempeextractor"], [25, 0, 0, "-", "i3truthextractor"], [26, 0, 0, "-", "i3tumextractor"], [27, 0, 0, "-", "utilities"]], "graphnet.data.extractors.i3extractor": [[15, 1, 1, "", "I3Extractor"], [15, 1, 1, "", "I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor.I3Extractor": [[15, 4, 1, "", "name"], [15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3extractor.I3ExtractorCollection": [[15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3featureextractor": [[16, 1, 1, "", "I3FeatureExtractor"], [16, 1, 1, "", "I3FeatureExtractorIceCube86"], [16, 1, 1, "", "I3FeatureExtractorIceCubeDeepCore"], [16, 1, 1, "", "I3FeatureExtractorIceCubeUpgrade"], [16, 1, 1, "", "I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3genericextractor": [[17, 1, 1, "", "I3GenericExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, 1, 1, "", "I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, 1, 1, "", "I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, 1, 1, "", "I3ParticleExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, 1, 1, "", "I3PISAExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, 1, 1, "", "I3QUESOExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, 1, 1, "", "I3RetroExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, 1, 1, "", "I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, 1, 1, "", "I3TruthExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, 1, 1, "", "I3TUMExtractor"]], "graphnet.data.extractors.utilities": [[28, 0, 0, "-", "collections"], [29, 0, 0, "-", "frames"], [30, 0, 0, "-", "types"]], "graphnet.data.extractors.utilities.collections": [[28, 5, 1, "", "flatten_nested_dictionary"], [28, 5, 1, "", "serialise"], [28, 5, 1, "", "transpose_list_of_dicts"]], "graphnet.data.extractors.utilities.frames": [[29, 5, 1, "", "frame_is_montecarlo"], [29, 5, 1, "", "frame_is_noise"], [29, 5, 1, "", "get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.types": [[30, 5, 1, "", "break_cyclic_recursion"], [30, 5, 1, "", "cast_object_to_pure_python"], [30, 5, 1, "", "cast_pulse_series_to_pure_python"], [30, 5, 1, "", "get_member_variables"], [30, 5, 1, "", "is_boost_class"], [30, 5, 1, "", "is_boost_enum"], [30, 5, 1, "", "is_icecube_class"], [30, 5, 1, "", "is_method"], [30, 5, 1, "", "is_type"]], "graphnet.data.parquet": [[32, 0, 0, "-", "parquet_dataconverter"]], "graphnet.data.parquet.parquet_dataconverter": [[32, 1, 1, "", "ParquetDataConverter"]], "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter": [[32, 2, 1, "", "file_suffix"], [32, 3, 1, "", "merge_files"], [32, 3, 1, "", "save_data"]], "graphnet.data.pipeline": [[33, 1, 1, "", "InSQLitePipeline"]], "graphnet.data.sqlite": [[35, 0, 0, "-", "sqlite_dataconverter"], [36, 0, 0, "-", "sqlite_utilities"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, 1, 1, "", "SQLiteDataConverter"], [35, 5, 1, "", "construct_dataframe"], [35, 5, 1, "", "is_mc_tree"], [35, 5, 1, "", "is_pulse_map"]], "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter": [[35, 3, 1, "", "any_pulsemap_is_non_empty"], [35, 2, 1, "", "file_suffix"], [35, 3, 1, "", "merge_files"], [35, 3, 1, "", "save_data"]], "graphnet.data.sqlite.sqlite_utilities": [[36, 5, 1, "", "attach_index"], [36, 5, 1, "", "create_table"], [36, 5, 1, "", "create_table_and_save_to_sql"], [36, 5, 1, "", "database_exists"], [36, 5, 1, "", "database_table_exists"], [36, 5, 1, "", "run_sql_code"], [36, 5, 1, "", "save_to_sql"]], "graphnet.data.utilities": [[38, 0, 0, "-", "parquet_to_sqlite"], [39, 0, 0, "-", "random"], [40, 0, 0, "-", "string_selection_resolver"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, 1, 1, "", "ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter": [[38, 3, 1, "", "run"]], "graphnet.data.utilities.random": [[39, 5, 1, "", "pairwise_shuffle"]], "graphnet.data.utilities.string_selection_resolver": [[40, 1, 1, "", "StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver": [[40, 3, 1, "", "resolve"]], "graphnet.deployment.i3modules": [[44, 0, 0, "-", "graphnet_module"]], "graphnet.deployment.i3modules.graphnet_module": [[44, 1, 1, "", "GraphNeTI3Module"], [44, 1, 1, "", "I3InferenceModule"], [44, 1, 1, "", "I3PulseCleanerModule"]], "graphnet.models": [[46, 0, 0, "-", "coarsening"], [47, 0, 0, "-", "components"], [50, 0, 0, "-", "detector"], [54, 0, 0, "-", "gnn"], [60, 0, 0, "-", "graphs"], [67, 0, 0, "-", "model"], [68, 0, 0, "-", "standard_model"], [69, 0, 0, "-", "task"], [73, 0, 0, "-", "utils"]], "graphnet.models.coarsening": [[46, 1, 1, "", "AttributeCoarsening"], [46, 1, 1, "", "Coarsening"], [46, 1, 1, "", "CustomDOMCoarsening"], [46, 1, 1, "", "DOMAndTimeWindowCoarsening"], [46, 1, 1, "", "DOMCoarsening"], [46, 5, 1, "", "unbatch_edge_index"]], "graphnet.models.coarsening.Coarsening": [[46, 3, 1, "", "forward"], [46, 2, 1, "", "reduce_options"]], "graphnet.models.components": [[48, 0, 0, "-", "layers"], [49, 0, 0, "-", "pool"]], "graphnet.models.components.layers": [[48, 1, 1, "", "DynEdgeConv"], [48, 1, 1, "", "DynTrans"], [48, 1, 1, "", "EdgeConvTito"]], "graphnet.models.components.layers.DynEdgeConv": [[48, 3, 1, "", "forward"]], "graphnet.models.components.layers.DynTrans": [[48, 3, 1, "", "forward"]], "graphnet.models.components.layers.EdgeConvTito": [[48, 3, 1, "", "forward"], [48, 3, 1, "", "message"], [48, 3, 1, "", "reset_parameters"]], "graphnet.models.components.pool": [[49, 5, 1, "", "group_by"], [49, 5, 1, "", "group_pulses_to_dom"], [49, 5, 1, "", "group_pulses_to_pmt"], [49, 5, 1, "", "min_pool"], [49, 5, 1, "", "min_pool_x"], [49, 5, 1, "", "std_pool"], [49, 5, 1, "", "std_pool_x"], [49, 5, 1, "", "sum_pool"], [49, 5, 1, "", "sum_pool_and_distribute"], [49, 5, 1, "", "sum_pool_x"]], "graphnet.models.detector": [[51, 0, 0, "-", "detector"], [52, 0, 0, "-", "icecube"], [53, 0, 0, "-", "prometheus"]], "graphnet.models.detector.detector": [[51, 1, 1, "", "Detector"]], "graphnet.models.detector.detector.Detector": [[51, 3, 1, "", "feature_map"], [51, 3, 1, "", "forward"]], "graphnet.models.detector.icecube": [[52, 1, 1, "", "IceCube86"], [52, 1, 1, "", "IceCubeDeepCore"], [52, 1, 1, "", "IceCubeKaggle"], [52, 1, 1, "", "IceCubeUpgrade"]], "graphnet.models.detector.icecube.IceCube86": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeDeepCore": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeKaggle": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeUpgrade": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.prometheus": [[53, 1, 1, "", "Prometheus"]], "graphnet.models.detector.prometheus.Prometheus": [[53, 3, 1, "", "feature_map"]], "graphnet.models.gnn": [[55, 0, 0, "-", "convnet"], [56, 0, 0, "-", "dynedge"], [57, 0, 0, "-", "dynedge_jinst"], [58, 0, 0, "-", "dynedge_kaggle_tito"], [59, 0, 0, "-", "gnn"]], "graphnet.models.gnn.convnet": [[55, 1, 1, "", "ConvNet"]], "graphnet.models.gnn.convnet.ConvNet": [[55, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge": [[56, 1, 1, "", "DynEdge"]], "graphnet.models.gnn.dynedge.DynEdge": [[56, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge_jinst": [[57, 1, 1, "", "DynEdgeJINST"]], "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST": [[57, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge_kaggle_tito": [[58, 1, 1, "", "DynEdgeTITO"]], "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO": [[58, 3, 1, "", "forward"]], "graphnet.models.gnn.gnn": [[59, 1, 1, "", "GNN"]], "graphnet.models.gnn.gnn.GNN": [[59, 3, 1, "", "forward"], [59, 4, 1, "", "nb_inputs"], [59, 4, 1, "", "nb_outputs"]], "graphnet.models.graphs": [[61, 0, 0, "-", "edges"], [63, 0, 0, "-", "graph_definition"], [64, 0, 0, "-", "graphs"], [65, 0, 0, "-", "nodes"]], "graphnet.models.graphs.edges": [[62, 0, 0, "-", "edges"]], "graphnet.models.graphs.edges.edges": [[62, 1, 1, "", "EdgeDefinition"], [62, 1, 1, "", "EuclideanEdges"], [62, 1, 1, "", "KNNEdges"], [62, 1, 1, "", "RadialEdges"]], "graphnet.models.graphs.edges.edges.EdgeDefinition": [[62, 3, 1, "", "forward"]], "graphnet.models.graphs.graph_definition": [[63, 1, 1, "", "GraphDefinition"]], "graphnet.models.graphs.graph_definition.GraphDefinition": [[63, 3, 1, "", "forward"]], "graphnet.models.graphs.graphs": [[64, 1, 1, "", "KNNGraph"]], "graphnet.models.graphs.nodes": [[66, 0, 0, "-", "nodes"]], "graphnet.models.graphs.nodes.nodes": [[66, 1, 1, "", "NodeDefinition"], [66, 1, 1, "", "NodesAsPulses"]], "graphnet.models.graphs.nodes.nodes.NodeDefinition": [[66, 3, 1, "", "forward"], [66, 4, 1, "", "nb_outputs"], [66, 3, 1, "", "set_number_of_inputs"]], "graphnet.models.model": [[67, 1, 1, "", "Model"]], "graphnet.models.model.Model": [[67, 3, 1, "", "fit"], [67, 3, 1, "", "forward"], [67, 3, 1, "", "from_config"], [67, 3, 1, "", "load"], [67, 3, 1, "", "load_state_dict"], [67, 3, 1, "", "predict"], [67, 3, 1, "", "predict_as_dataframe"], [67, 3, 1, "", "save"], [67, 3, 1, "", "save_state_dict"]], "graphnet.models.standard_model": [[68, 1, 1, "", "StandardModel"]], "graphnet.models.standard_model.StandardModel": [[68, 3, 1, "", "compute_loss"], [68, 3, 1, "", "configure_optimizers"], [68, 3, 1, "", "forward"], [68, 3, 1, "", "inference"], [68, 3, 1, "", "predict"], [68, 3, 1, "", "predict_as_dataframe"], [68, 4, 1, "", "prediction_labels"], [68, 3, 1, "", "shared_step"], [68, 4, 1, "", "target_labels"], [68, 3, 1, "", "train"], [68, 3, 1, "", "training_step"], [68, 3, 1, "", "validation_step"]], "graphnet.models.task": [[70, 0, 0, "-", "classification"], [71, 0, 0, "-", "reconstruction"], [72, 0, 0, "-", "task"]], "graphnet.models.task.classification": [[70, 1, 1, "", "BinaryClassificationTask"], [70, 1, 1, "", "BinaryClassificationTaskLogits"], [70, 1, 1, "", "MulticlassClassificationTask"]], "graphnet.models.task.classification.BinaryClassificationTask": [[70, 2, 1, "", "default_prediction_labels"], [70, 2, 1, "", "default_target_labels"], [70, 2, 1, "", "nb_inputs"]], "graphnet.models.task.classification.BinaryClassificationTaskLogits": [[70, 2, 1, "", "default_prediction_labels"], [70, 2, 1, "", "default_target_labels"], [70, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction": [[71, 1, 1, "", "AzimuthReconstruction"], [71, 1, 1, "", "AzimuthReconstructionWithKappa"], [71, 1, 1, "", "DirectionReconstructionWithKappa"], [71, 1, 1, "", "EnergyReconstruction"], [71, 1, 1, "", "EnergyReconstructionWithPower"], [71, 1, 1, "", "EnergyReconstructionWithUncertainty"], [71, 1, 1, "", "InelasticityReconstruction"], [71, 1, 1, "", "PositionReconstruction"], [71, 1, 1, "", "TimeReconstruction"], [71, 1, 1, "", "VertexReconstruction"], [71, 1, 1, "", "ZenithReconstruction"], [71, 1, 1, "", "ZenithReconstructionWithKappa"]], "graphnet.models.task.reconstruction.AzimuthReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstructionWithPower": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.InelasticityReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.PositionReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.TimeReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.VertexReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.ZenithReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.task": [[72, 1, 1, "", "IdentityTask"], [72, 1, 1, "", "Task"]], "graphnet.models.task.task.IdentityTask": [[72, 4, 1, "", "default_prediction_labels"], [72, 4, 1, "", "default_target_labels"], [72, 4, 1, "", "nb_inputs"]], "graphnet.models.task.task.Task": [[72, 3, 1, "", "compute_loss"], [72, 4, 1, "", "default_prediction_labels"], [72, 4, 1, "", "default_target_labels"], [72, 3, 1, "", "forward"], [72, 3, 1, "", "inference"], [72, 4, 1, "", "nb_inputs"], [72, 3, 1, "", "train_eval"]], "graphnet.models.utils": [[73, 5, 1, "", "calculate_distance_matrix"], [73, 5, 1, "", "calculate_xyzt_homophily"], [73, 5, 1, "", "knn_graph_batch"]], "graphnet.pisa": [[75, 0, 0, "-", "fitting"], [76, 0, 0, "-", "plotting"]], "graphnet.pisa.fitting": [[75, 1, 1, "", "ContourFitter"], [75, 1, 1, "", "WeightFitter"], [75, 5, 1, "", "config_updater"]], "graphnet.pisa.fitting.ContourFitter": [[75, 3, 1, "", "fit_1d_contour"], [75, 3, 1, "", "fit_2d_contour"]], "graphnet.pisa.fitting.WeightFitter": [[75, 3, 1, "", "fit_weights"]], "graphnet.pisa.plotting": [[76, 5, 1, "", "plot_1D_contour"], [76, 5, 1, "", "plot_2D_contour"], [76, 5, 1, "", "read_entry"]], "graphnet.training": [[78, 0, 0, "-", "callbacks"], [79, 0, 0, "-", "labels"], [80, 0, 0, "-", "loss_functions"], [81, 0, 0, "-", "utils"], [82, 0, 0, "-", "weight_fitting"]], "graphnet.training.callbacks": [[78, 1, 1, "", "PiecewiseLinearLR"], [78, 1, 1, "", "ProgressBar"]], "graphnet.training.callbacks.PiecewiseLinearLR": [[78, 3, 1, "", "get_lr"]], "graphnet.training.callbacks.ProgressBar": [[78, 3, 1, "", "get_metrics"], [78, 3, 1, "", "init_predict_tqdm"], [78, 3, 1, "", "init_test_tqdm"], [78, 3, 1, "", "init_train_tqdm"], [78, 3, 1, "", "init_validation_tqdm"], [78, 3, 1, "", "on_train_epoch_end"], [78, 3, 1, "", "on_train_epoch_start"]], "graphnet.training.labels": [[79, 1, 1, "", "Direction"], [79, 1, 1, "", "Label"]], "graphnet.training.labels.Label": [[79, 4, 1, "", "key"]], "graphnet.training.loss_functions": [[80, 1, 1, "", "BinaryCrossEntropyLoss"], [80, 1, 1, "", "CrossEntropyLoss"], [80, 1, 1, "", "EuclideanDistanceLoss"], [80, 1, 1, "", "LogCMK"], [80, 1, 1, "", "LogCoshLoss"], [80, 1, 1, "", "LossFunction"], [80, 1, 1, "", "MSELoss"], [80, 1, 1, "", "RMSELoss"], [80, 1, 1, "", "VonMisesFisher2DLoss"], [80, 1, 1, "", "VonMisesFisher3DLoss"], [80, 1, 1, "", "VonMisesFisherLoss"]], "graphnet.training.loss_functions.LogCMK": [[80, 3, 1, "", "backward"], [80, 3, 1, "", "forward"]], "graphnet.training.loss_functions.LossFunction": [[80, 3, 1, "", "forward"]], "graphnet.training.loss_functions.VonMisesFisherLoss": [[80, 3, 1, "", "log_cmk"], [80, 3, 1, "", "log_cmk_approx"], [80, 3, 1, "", "log_cmk_exact"]], "graphnet.training.utils": [[81, 5, 1, "", "collate_fn"], [81, 5, 1, "", "get_predictions"], [81, 5, 1, "", "make_dataloader"], [81, 5, 1, "", "make_train_validation_dataloader"], [81, 5, 1, "", "save_results"]], "graphnet.training.weight_fitting": [[82, 1, 1, "", "BjoernLow"], [82, 1, 1, "", "Uniform"], [82, 1, 1, "", "WeightFitter"]], "graphnet.training.weight_fitting.WeightFitter": [[82, 3, 1, "", "fit"]], "graphnet.utilities": [[84, 0, 0, "-", "argparse"], [85, 0, 0, "-", "config"], [92, 0, 0, "-", "decorators"], [93, 0, 0, "-", "filesys"], [94, 0, 0, "-", "imports"], [95, 0, 0, "-", "logging"], [96, 0, 0, "-", "maths"]], "graphnet.utilities.argparse": [[84, 1, 1, "", "ArgumentParser"], [84, 1, 1, "", "Options"]], "graphnet.utilities.argparse.ArgumentParser": [[84, 2, 1, "", "standard_arguments"], [84, 3, 1, "", "with_standard_arguments"]], "graphnet.utilities.argparse.Options": [[84, 3, 1, "", "contains"], [84, 3, 1, "", "pop_default"]], "graphnet.utilities.config": [[86, 0, 0, "-", "base_config"], [87, 0, 0, "-", "configurable"], [88, 0, 0, "-", "dataset_config"], [89, 0, 0, "-", "model_config"], [90, 0, 0, "-", "parsing"], [91, 0, 0, "-", "training_config"]], "graphnet.utilities.config.base_config": [[86, 1, 1, "", "BaseConfig"], [86, 5, 1, "", "get_all_argument_values"]], "graphnet.utilities.config.base_config.BaseConfig": [[86, 3, 1, "", "as_dict"], [86, 3, 1, "", "dump"], [86, 3, 1, "", "load"], [86, 2, 1, "", "model_config"], [86, 2, 1, "", "model_fields"]], "graphnet.utilities.config.configurable": [[87, 1, 1, "", "Configurable"]], "graphnet.utilities.config.configurable.Configurable": [[87, 4, 1, "", "config"], [87, 3, 1, "", "from_config"], [87, 3, 1, "", "save_config"]], "graphnet.utilities.config.dataset_config": [[88, 1, 1, "", "DatasetConfig"], [88, 5, 1, "", "save_dataset_config"]], "graphnet.utilities.config.dataset_config.DatasetConfig": [[88, 3, 1, "", "as_dict"], [88, 2, 1, "", "features"], [88, 2, 1, "", "graph_definition"], [88, 2, 1, "", "index_column"], [88, 2, 1, "", "loss_weight_column"], [88, 2, 1, "", "loss_weight_default_value"], [88, 2, 1, "", "loss_weight_table"], [88, 2, 1, "", "model_config"], [88, 2, 1, "", "model_fields"], [88, 2, 1, "", "node_truth"], [88, 2, 1, "", "node_truth_table"], [88, 2, 1, "", "path"], [88, 2, 1, "", "pulsemaps"], [88, 2, 1, "", "seed"], [88, 2, 1, "", "selection"], [88, 2, 1, "", "string_selection"], [88, 2, 1, "", "truth"], [88, 2, 1, "", "truth_table"]], "graphnet.utilities.config.model_config": [[89, 1, 1, "", "ModelConfig"], [89, 5, 1, "", "save_model_config"]], "graphnet.utilities.config.model_config.ModelConfig": [[89, 2, 1, "", "arguments"], [89, 3, 1, "", "as_dict"], [89, 2, 1, "", "class_name"], [89, 2, 1, "", "model_config"], [89, 2, 1, "", "model_fields"]], "graphnet.utilities.config.parsing": [[90, 5, 1, "", "get_all_grapnet_classes"], [90, 5, 1, "", "get_graphnet_classes"], [90, 5, 1, "", "is_graphnet_class"], [90, 5, 1, "", "is_graphnet_module"], [90, 5, 1, "", "list_all_submodules"], [90, 5, 1, "", "traverse_and_apply"]], "graphnet.utilities.config.training_config": [[91, 1, 1, "", "TrainingConfig"]], "graphnet.utilities.config.training_config.TrainingConfig": [[91, 2, 1, "", "dataloader"], [91, 2, 1, "", "early_stopping_patience"], [91, 2, 1, "", "fit"], [91, 2, 1, "", "model_config"], [91, 2, 1, "", "model_fields"], [91, 2, 1, "", "target"]], "graphnet.utilities.filesys": [[93, 5, 1, "", "find_i3_files"], [93, 5, 1, "", "has_extension"], [93, 5, 1, "", "is_gcd_file"], [93, 5, 1, "", "is_i3_file"]], "graphnet.utilities.imports": [[94, 5, 1, "", "has_icecube_package"], [94, 5, 1, "", "has_pisa_package"], [94, 5, 1, "", "has_torch_package"], [94, 5, 1, "", "requires_icecube"]], "graphnet.utilities.logging": [[95, 1, 1, "", "Logger"], [95, 1, 1, "", "RepeatFilter"]], "graphnet.utilities.logging.Logger": [[95, 3, 1, "", "critical"], [95, 3, 1, "", "debug"], [95, 3, 1, "", "error"], [95, 4, 1, "", "file_handlers"], [95, 4, 1, "", "handlers"], [95, 3, 1, "", "info"], [95, 3, 1, "", "setLevel"], [95, 4, 1, "", "stream_handlers"], [95, 3, 1, "", "warning"], [95, 3, 1, "", "warning_once"]], "graphnet.utilities.logging.RepeatFilter": [[95, 3, 1, "", "filter"], [95, 2, 1, "", "nb_repeats_allowed"]], "graphnet.utilities.maths": [[96, 5, 1, "", "eps_like"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:attribute", "3": "py:method", "4": "py:property", "5": "py:function", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "attribute", "Python attribute"], "3": ["py", "method", "Python method"], "4": ["py", "property", "Python property"], "5": ["py", "function", "Python function"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"about": [0, 99], "impact": [0, 99], "usag": [0, 99], "acknowledg": [0, 99], "api": 1, "constant": [2, 4], "data": 3, "dataconvert": 5, "dataload": 6, "dataset": [7, 8], "parquet": [9, 31], "parquet_dataset": 10, "sqlite": [11, 34], "sqlite_dataset": 12, "sqlite_dataset_perturb": 13, "extractor": 14, "i3extractor": 15, "i3featureextractor": 16, "i3genericextractor": 17, "i3hybridrecoextractor": 18, "i3ntmuonlabelsextractor": 19, "i3particleextractor": 20, "i3pisaextractor": 21, "i3quesoextractor": 22, "i3retroextractor": 23, "i3splinempeextractor": 24, "i3truthextractor": 25, "i3tumextractor": 26, "util": [27, 37, 73, 81, 83], "collect": 28, "frame": 29, "type": 30, "parquet_dataconvert": 32, "pipelin": 33, "sqlite_dataconvert": 35, "sqlite_util": 36, "parquet_to_sqlit": 38, "random": 39, "string_selection_resolv": 40, "deploy": [41, 43], "i3modul": 42, "graphnet_modul": 44, "model": [45, 67], "coarsen": 46, "compon": 47, "layer": 48, "pool": 49, "detector": [50, 51], "icecub": 52, "prometheu": 53, "gnn": [54, 59], "convnet": 55, "dynedg": 56, "dynedge_jinst": 57, "dynedge_kaggle_tito": 58, "graph": [60, 64], "edg": [61, 62], "graph_definit": 63, "node": [65, 66], "standard_model": 68, "task": [69, 72], "classif": 70, "reconstruct": 71, "pisa": 74, "fit": 75, "plot": 76, "train": 77, "callback": 78, "label": 79, "loss_funct": 80, "weight_fit": 82, "argpars": 84, "config": 85, "base_config": 86, "configur": 87, "dataset_config": 88, "model_config": 89, "pars": 90, "training_config": 91, "decor": 92, "filesi": 93, "import": 94, "log": 95, "math": 96, "src": 97, "contribut": 98, "github": 98, "issu": 98, "pull": 98, "request": 98, "convent": 98, "code": 98, "qualiti": 98, "instal": 100, "icetrai": 100, "stand": 100, "alon": 100, "run": 100, "docker": 100}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 60}, "alltitles": {"About": [[0, "about"], [99, "about"]], "Impact": [[0, "impact"], [99, "impact"]], "Usage": [[0, "usage"], [99, "usage"]], "Acknowledgements": [[0, "acknowledgements"], [99, "acknowledgements"]], "API": [[1, "module-graphnet"]], "constants": [[2, "module-graphnet.constants"], [4, "module-graphnet.data.constants"]], "data": [[3, "module-graphnet.data"]], "dataconverter": [[5, "module-graphnet.data.dataconverter"]], "dataloader": [[6, "module-graphnet.data.dataloader"]], "dataset": [[7, "module-graphnet.data.dataset"], [8, "module-graphnet.data.dataset.dataset"]], "parquet": [[9, "module-graphnet.data.dataset.parquet"], [31, "module-graphnet.data.parquet"]], "parquet_dataset": [[10, "module-graphnet.data.dataset.parquet.parquet_dataset"]], "sqlite": [[11, "module-graphnet.data.dataset.sqlite"], [34, "module-graphnet.data.sqlite"]], "sqlite_dataset": [[12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"]], "sqlite_dataset_perturbed": [[13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"]], "extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "utilities": [[27, "module-graphnet.data.extractors.utilities"], [37, "module-graphnet.data.utilities"], [83, "module-graphnet.utilities"]], "collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "types": [[30, "module-graphnet.data.extractors.utilities.types"]], "parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "pipeline": [[33, "module-graphnet.data.pipeline"]], "sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "random": [[39, "module-graphnet.data.utilities.random"]], "string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "deployment": [[41, "module-graphnet.deployment"]], "i3modules": [[42, "i3modules"]], "deployer": [[43, "deployer"]], "graphnet_module": [[44, "module-graphnet.deployment.i3modules.graphnet_module"]], "models": [[45, "module-graphnet.models"]], "coarsening": [[46, "module-graphnet.models.coarsening"]], "components": [[47, "module-graphnet.models.components"]], "layers": [[48, "module-graphnet.models.components.layers"]], "pool": [[49, "module-graphnet.models.components.pool"]], "detector": [[50, "module-graphnet.models.detector"], [51, "module-graphnet.models.detector.detector"]], "icecube": [[52, "module-graphnet.models.detector.icecube"]], "prometheus": [[53, "module-graphnet.models.detector.prometheus"]], "gnn": [[54, "module-graphnet.models.gnn"], [59, "module-graphnet.models.gnn.gnn"]], "convnet": [[55, "module-graphnet.models.gnn.convnet"]], "dynedge": [[56, "module-graphnet.models.gnn.dynedge"]], "dynedge_jinst": [[57, "module-graphnet.models.gnn.dynedge_jinst"]], "dynedge_kaggle_tito": [[58, "module-graphnet.models.gnn.dynedge_kaggle_tito"]], "graphs": [[60, "module-graphnet.models.graphs"], [64, "module-graphnet.models.graphs.graphs"]], "edges": [[61, "module-graphnet.models.graphs.edges"], [62, "module-graphnet.models.graphs.edges.edges"]], "graph_definition": [[63, "module-graphnet.models.graphs.graph_definition"]], "nodes": [[65, "module-graphnet.models.graphs.nodes"], [66, "module-graphnet.models.graphs.nodes.nodes"]], "model": [[67, "module-graphnet.models.model"]], "standard_model": [[68, "module-graphnet.models.standard_model"]], "task": [[69, "module-graphnet.models.task"], [72, "module-graphnet.models.task.task"]], "classification": [[70, "module-graphnet.models.task.classification"]], "reconstruction": [[71, "module-graphnet.models.task.reconstruction"]], "utils": [[73, "module-graphnet.models.utils"], [81, "module-graphnet.training.utils"]], "pisa": [[74, "module-graphnet.pisa"]], "fitting": [[75, "module-graphnet.pisa.fitting"]], "plotting": [[76, "module-graphnet.pisa.plotting"]], "training": [[77, "module-graphnet.training"]], "callbacks": [[78, "module-graphnet.training.callbacks"]], "labels": [[79, "module-graphnet.training.labels"]], "loss_functions": [[80, "module-graphnet.training.loss_functions"]], "weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "argparse": [[84, "module-graphnet.utilities.argparse"]], "config": [[85, "module-graphnet.utilities.config"]], "base_config": [[86, "module-graphnet.utilities.config.base_config"]], "configurable": [[87, "module-graphnet.utilities.config.configurable"]], "dataset_config": [[88, "module-graphnet.utilities.config.dataset_config"]], "model_config": [[89, "module-graphnet.utilities.config.model_config"]], "parsing": [[90, "module-graphnet.utilities.config.parsing"]], "training_config": [[91, "module-graphnet.utilities.config.training_config"]], "decorators": [[92, "module-graphnet.utilities.decorators"]], "filesys": [[93, "module-graphnet.utilities.filesys"]], "imports": [[94, "module-graphnet.utilities.imports"]], "logging": [[95, "module-graphnet.utilities.logging"]], "maths": [[96, "module-graphnet.utilities.maths"]], "src": [[97, "src"]], "Contribute": [[98, "contribute"]], "GitHub issues": [[98, "github-issues"]], "Pull requests": [[98, "pull-requests"]], "Conventions": [[98, "conventions"]], "Code quality": [[98, "code-quality"]], "Install": [[100, "install"]], "Installing with IceTray": [[100, "installing-with-icetray"]], "Installing stand-alone": [[100, "installing-stand-alone"]], "Running in Docker": [[100, "running-in-docker"]]}, "indexentries": {"graphnet": [[1, "module-graphnet"]], "module": [[1, "module-graphnet"], [2, "module-graphnet.constants"], [3, "module-graphnet.data"], [4, "module-graphnet.data.constants"], [5, "module-graphnet.data.dataconverter"], [6, "module-graphnet.data.dataloader"], [7, "module-graphnet.data.dataset"], [8, "module-graphnet.data.dataset.dataset"], [9, "module-graphnet.data.dataset.parquet"], [10, "module-graphnet.data.dataset.parquet.parquet_dataset"], [11, "module-graphnet.data.dataset.sqlite"], [12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"], [13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"], [14, "module-graphnet.data.extractors"], [15, "module-graphnet.data.extractors.i3extractor"], [16, "module-graphnet.data.extractors.i3featureextractor"], [17, "module-graphnet.data.extractors.i3genericextractor"], [18, "module-graphnet.data.extractors.i3hybridrecoextractor"], [19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"], [20, "module-graphnet.data.extractors.i3particleextractor"], [21, "module-graphnet.data.extractors.i3pisaextractor"], [22, "module-graphnet.data.extractors.i3quesoextractor"], [23, "module-graphnet.data.extractors.i3retroextractor"], [24, "module-graphnet.data.extractors.i3splinempeextractor"], [25, "module-graphnet.data.extractors.i3truthextractor"], [26, "module-graphnet.data.extractors.i3tumextractor"], [27, "module-graphnet.data.extractors.utilities"], [28, "module-graphnet.data.extractors.utilities.collections"], [29, "module-graphnet.data.extractors.utilities.frames"], [30, "module-graphnet.data.extractors.utilities.types"], [31, "module-graphnet.data.parquet"], [32, "module-graphnet.data.parquet.parquet_dataconverter"], [33, "module-graphnet.data.pipeline"], [34, "module-graphnet.data.sqlite"], [35, "module-graphnet.data.sqlite.sqlite_dataconverter"], [36, "module-graphnet.data.sqlite.sqlite_utilities"], [37, "module-graphnet.data.utilities"], [38, "module-graphnet.data.utilities.parquet_to_sqlite"], [39, "module-graphnet.data.utilities.random"], [40, "module-graphnet.data.utilities.string_selection_resolver"], [41, "module-graphnet.deployment"], [44, "module-graphnet.deployment.i3modules.graphnet_module"], [45, "module-graphnet.models"], [46, "module-graphnet.models.coarsening"], [47, "module-graphnet.models.components"], [48, "module-graphnet.models.components.layers"], [49, "module-graphnet.models.components.pool"], [50, "module-graphnet.models.detector"], [51, "module-graphnet.models.detector.detector"], [52, "module-graphnet.models.detector.icecube"], [53, "module-graphnet.models.detector.prometheus"], [54, "module-graphnet.models.gnn"], [55, "module-graphnet.models.gnn.convnet"], [56, "module-graphnet.models.gnn.dynedge"], [57, "module-graphnet.models.gnn.dynedge_jinst"], [58, "module-graphnet.models.gnn.dynedge_kaggle_tito"], [59, "module-graphnet.models.gnn.gnn"], [60, "module-graphnet.models.graphs"], [61, "module-graphnet.models.graphs.edges"], [62, "module-graphnet.models.graphs.edges.edges"], [63, "module-graphnet.models.graphs.graph_definition"], [64, "module-graphnet.models.graphs.graphs"], [65, "module-graphnet.models.graphs.nodes"], [66, "module-graphnet.models.graphs.nodes.nodes"], [67, "module-graphnet.models.model"], [68, "module-graphnet.models.standard_model"], [69, "module-graphnet.models.task"], [70, "module-graphnet.models.task.classification"], [71, "module-graphnet.models.task.reconstruction"], [72, "module-graphnet.models.task.task"], [73, "module-graphnet.models.utils"], [74, "module-graphnet.pisa"], [75, "module-graphnet.pisa.fitting"], [76, "module-graphnet.pisa.plotting"], [77, "module-graphnet.training"], [78, "module-graphnet.training.callbacks"], [79, "module-graphnet.training.labels"], [80, "module-graphnet.training.loss_functions"], [81, "module-graphnet.training.utils"], [82, "module-graphnet.training.weight_fitting"], [83, "module-graphnet.utilities"], [84, "module-graphnet.utilities.argparse"], [85, "module-graphnet.utilities.config"], [86, "module-graphnet.utilities.config.base_config"], [87, "module-graphnet.utilities.config.configurable"], [88, "module-graphnet.utilities.config.dataset_config"], [89, "module-graphnet.utilities.config.model_config"], [90, "module-graphnet.utilities.config.parsing"], [91, "module-graphnet.utilities.config.training_config"], [92, "module-graphnet.utilities.decorators"], [93, "module-graphnet.utilities.filesys"], [94, "module-graphnet.utilities.imports"], [95, "module-graphnet.utilities.logging"], [96, "module-graphnet.utilities.maths"]], "graphnet.constants": [[2, "module-graphnet.constants"]], "graphnet.data": [[3, "module-graphnet.data"]], "deepcore (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.DEEPCORE"]], "deepcore (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.DEEPCORE"]], "features (class in graphnet.data.constants)": [[4, "graphnet.data.constants.FEATURES"]], "icecube86 (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.ICECUBE86"]], "icecube86 (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.ICECUBE86"]], "kaggle (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.KAGGLE"]], "kaggle (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.KAGGLE"]], "prometheus (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.PROMETHEUS"]], "prometheus (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.PROMETHEUS"]], "truth (class in graphnet.data.constants)": [[4, "graphnet.data.constants.TRUTH"]], "upgrade (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.UPGRADE"]], "upgrade (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.UPGRADE"]], "graphnet.data.constants": [[4, "module-graphnet.data.constants"]], "dataconverter (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.DataConverter"]], "fileset (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.FileSet"]], "cache_output_files() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.cache_output_files"]], "execute() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.execute"]], "file_suffix (graphnet.data.dataconverter.dataconverter property)": [[5, "graphnet.data.dataconverter.DataConverter.file_suffix"]], "gcd_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.gcd_file"]], "get_map_function() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.get_map_function"]], "graphnet.data.dataconverter": [[5, "module-graphnet.data.dataconverter"]], "i3_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.i3_file"]], "init_global_index() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.init_global_index"]], "merge_files() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.merge_files"]], "save_data() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.save_data"]], "dataloader (class in graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.DataLoader"]], "collate_fn() (in module graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.collate_fn"]], "do_shuffle() (in module graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.do_shuffle"]], "from_dataset_config() (graphnet.data.dataloader.dataloader class method)": [[6, "graphnet.data.dataloader.DataLoader.from_dataset_config"]], "graphnet.data.dataloader": [[6, "module-graphnet.data.dataloader"]], "graphnet.data.dataset": [[7, "module-graphnet.data.dataset"]], "columnmissingexception": [[8, "graphnet.data.dataset.dataset.ColumnMissingException"]], "dataset (class in graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.Dataset"]], "ensembledataset (class in graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.EnsembleDataset"]], "add_label() (graphnet.data.dataset.dataset.dataset method)": [[8, "graphnet.data.dataset.dataset.Dataset.add_label"]], "concatenate() (graphnet.data.dataset.dataset.dataset class method)": [[8, "graphnet.data.dataset.dataset.Dataset.concatenate"]], "from_config() (graphnet.data.dataset.dataset.dataset class method)": [[8, "graphnet.data.dataset.dataset.Dataset.from_config"]], "graphnet.data.dataset.dataset": [[8, "module-graphnet.data.dataset.dataset"]], "load_module() (in module graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.load_module"]], "parse_graph_definition() (in module graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.parse_graph_definition"]], "path (graphnet.data.dataset.dataset.dataset property)": [[8, "graphnet.data.dataset.dataset.Dataset.path"]], "query_table() (graphnet.data.dataset.dataset.dataset method)": [[8, "graphnet.data.dataset.dataset.Dataset.query_table"]], "truth_table (graphnet.data.dataset.dataset.dataset property)": [[8, "graphnet.data.dataset.dataset.Dataset.truth_table"]], "graphnet.data.dataset.parquet": [[9, "module-graphnet.data.dataset.parquet"]], "parquetdataset (class in graphnet.data.dataset.parquet.parquet_dataset)": [[10, "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset"]], "graphnet.data.dataset.parquet.parquet_dataset": [[10, "module-graphnet.data.dataset.parquet.parquet_dataset"]], "query_table() (graphnet.data.dataset.parquet.parquet_dataset.parquetdataset method)": [[10, "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table"]], "graphnet.data.dataset.sqlite": [[11, "module-graphnet.data.dataset.sqlite"]], "sqlitedataset (class in graphnet.data.dataset.sqlite.sqlite_dataset)": [[12, "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset"]], "graphnet.data.dataset.sqlite.sqlite_dataset": [[12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"]], "query_table() (graphnet.data.dataset.sqlite.sqlite_dataset.sqlitedataset method)": [[12, "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table"]], "sqlitedatasetperturbed (class in graphnet.data.dataset.sqlite.sqlite_dataset_perturbed)": [[13, "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed"]], "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed": [[13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"]], "graphnet.data.extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor"]], "i3extractorcollection (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "name (graphnet.data.extractors.i3extractor.i3extractor property)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.name"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractor method)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.set_files"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractorcollection method)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection.set_files"]], "i3featureextractor (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"]], "i3featureextractoricecube86 (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCube86"]], "i3featureextractoricecubedeepcore (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeDeepCore"]], "i3featureextractoricecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeUpgrade"]], "i3pulsenoisetruthflagicecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor (class in graphnet.data.extractors.i3genericextractor)": [[17, "graphnet.data.extractors.i3genericextractor.I3GenericExtractor"]], "graphnet.data.extractors.i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3galacticplanehybridrecoextractor (class in graphnet.data.extractors.i3hybridrecoextractor)": [[18, "graphnet.data.extractors.i3hybridrecoextractor.I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelextractor (class in graphnet.data.extractors.i3ntmuonlabelsextractor)": [[19, "graphnet.data.extractors.i3ntmuonlabelsextractor.I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor (class in graphnet.data.extractors.i3particleextractor)": [[20, "graphnet.data.extractors.i3particleextractor.I3ParticleExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor (class in graphnet.data.extractors.i3pisaextractor)": [[21, "graphnet.data.extractors.i3pisaextractor.I3PISAExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor (class in graphnet.data.extractors.i3quesoextractor)": [[22, "graphnet.data.extractors.i3quesoextractor.I3QUESOExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor (class in graphnet.data.extractors.i3retroextractor)": [[23, "graphnet.data.extractors.i3retroextractor.I3RetroExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeicextractor (class in graphnet.data.extractors.i3splinempeextractor)": [[24, "graphnet.data.extractors.i3splinempeextractor.I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor (class in graphnet.data.extractors.i3truthextractor)": [[25, "graphnet.data.extractors.i3truthextractor.I3TruthExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor (class in graphnet.data.extractors.i3tumextractor)": [[26, "graphnet.data.extractors.i3tumextractor.I3TUMExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "graphnet.data.extractors.utilities": [[27, "module-graphnet.data.extractors.utilities"]], "flatten_nested_dictionary() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.flatten_nested_dictionary"]], "graphnet.data.extractors.utilities.collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "serialise() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.serialise"]], "transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.transpose_list_of_dicts"]], "frame_is_montecarlo() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_montecarlo"]], "frame_is_noise() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_noise"]], "get_om_keys_and_pulseseries() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "break_cyclic_recursion() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.break_cyclic_recursion"]], "cast_object_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_object_to_pure_python"]], "cast_pulse_series_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_pulse_series_to_pure_python"]], "get_member_variables() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.get_member_variables"]], "graphnet.data.extractors.utilities.types": [[30, "module-graphnet.data.extractors.utilities.types"]], "is_boost_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_class"]], "is_boost_enum() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_enum"]], "is_icecube_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_icecube_class"]], "is_method() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_method"]], "is_type() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_type"]], "graphnet.data.parquet": [[31, "module-graphnet.data.parquet"]], "parquetdataconverter (class in graphnet.data.parquet.parquet_dataconverter)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter"]], "file_suffix (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter attribute)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.file_suffix"]], "graphnet.data.parquet.parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "merge_files() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.merge_files"]], "save_data() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.save_data"]], "insqlitepipeline (class in graphnet.data.pipeline)": [[33, "graphnet.data.pipeline.InSQLitePipeline"]], "graphnet.data.pipeline": [[33, "module-graphnet.data.pipeline"]], "graphnet.data.sqlite": [[34, "module-graphnet.data.sqlite"]], "sqlitedataconverter (class in graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter"]], "any_pulsemap_is_non_empty() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.any_pulsemap_is_non_empty"]], "construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe"]], "file_suffix (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter attribute)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.file_suffix"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "is_mc_tree() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_mc_tree"]], "is_pulse_map() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_pulse_map"]], "merge_files() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.merge_files"]], "save_data() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.save_data"]], "attach_index() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.attach_index"]], "create_table() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table"]], "create_table_and_save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table_and_save_to_sql"]], "database_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_exists"]], "database_table_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_table_exists"]], "graphnet.data.sqlite.sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "run_sql_code() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.run_sql_code"]], "save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.save_to_sql"]], "graphnet.data.utilities": [[37, "module-graphnet.data.utilities"]], "parquettosqliteconverter (class in graphnet.data.utilities.parquet_to_sqlite)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "run() (graphnet.data.utilities.parquet_to_sqlite.parquettosqliteconverter method)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter.run"]], "graphnet.data.utilities.random": [[39, "module-graphnet.data.utilities.random"]], "pairwise_shuffle() (in module graphnet.data.utilities.random)": [[39, "graphnet.data.utilities.random.pairwise_shuffle"]], "stringselectionresolver (class in graphnet.data.utilities.string_selection_resolver)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "resolve() (graphnet.data.utilities.string_selection_resolver.stringselectionresolver method)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver.resolve"]], "graphnet.deployment": [[41, "module-graphnet.deployment"]], "graphneti3module (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module"]], "i3inferencemodule (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.I3InferenceModule"]], "i3pulsecleanermodule (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule"]], "graphnet.deployment.i3modules.graphnet_module": [[44, "module-graphnet.deployment.i3modules.graphnet_module"]], "graphnet.models": [[45, "module-graphnet.models"]], "attributecoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.AttributeCoarsening"]], "coarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.Coarsening"]], "customdomcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.CustomDOMCoarsening"]], "domandtimewindowcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.DOMAndTimeWindowCoarsening"]], "domcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.DOMCoarsening"]], "forward() (graphnet.models.coarsening.coarsening method)": [[46, "graphnet.models.coarsening.Coarsening.forward"]], "graphnet.models.coarsening": [[46, "module-graphnet.models.coarsening"]], "reduce_options (graphnet.models.coarsening.coarsening attribute)": [[46, "graphnet.models.coarsening.Coarsening.reduce_options"]], "unbatch_edge_index() (in module graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.unbatch_edge_index"]], "graphnet.models.components": [[47, "module-graphnet.models.components"]], "dynedgeconv (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.DynEdgeConv"]], "dyntrans (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.DynTrans"]], "edgeconvtito (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.EdgeConvTito"]], "forward() (graphnet.models.components.layers.dynedgeconv method)": [[48, "graphnet.models.components.layers.DynEdgeConv.forward"]], "forward() (graphnet.models.components.layers.dyntrans method)": [[48, "graphnet.models.components.layers.DynTrans.forward"]], "forward() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.forward"]], "graphnet.models.components.layers": [[48, "module-graphnet.models.components.layers"]], "message() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.message"]], "reset_parameters() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.reset_parameters"]], "graphnet.models.components.pool": [[49, "module-graphnet.models.components.pool"]], "group_by() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_by"]], "group_pulses_to_dom() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_pulses_to_dom"]], "group_pulses_to_pmt() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_pulses_to_pmt"]], "min_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.min_pool"]], "min_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.min_pool_x"]], "std_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.std_pool"]], "std_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.std_pool_x"]], "sum_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool"]], "sum_pool_and_distribute() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool_and_distribute"]], "sum_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool_x"]], "graphnet.models.detector": [[50, "module-graphnet.models.detector"]], "detector (class in graphnet.models.detector.detector)": [[51, "graphnet.models.detector.detector.Detector"]], "feature_map() (graphnet.models.detector.detector.detector method)": [[51, "graphnet.models.detector.detector.Detector.feature_map"]], "forward() (graphnet.models.detector.detector.detector method)": [[51, "graphnet.models.detector.detector.Detector.forward"]], "graphnet.models.detector.detector": [[51, "module-graphnet.models.detector.detector"]], "icecube86 (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCube86"]], "icecubedeepcore (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeDeepCore"]], "icecubekaggle (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeKaggle"]], "icecubeupgrade (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeUpgrade"]], "feature_map() (graphnet.models.detector.icecube.icecube86 method)": [[52, "graphnet.models.detector.icecube.IceCube86.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubedeepcore method)": [[52, "graphnet.models.detector.icecube.IceCubeDeepCore.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubekaggle method)": [[52, "graphnet.models.detector.icecube.IceCubeKaggle.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubeupgrade method)": [[52, "graphnet.models.detector.icecube.IceCubeUpgrade.feature_map"]], "graphnet.models.detector.icecube": [[52, "module-graphnet.models.detector.icecube"]], "prometheus (class in graphnet.models.detector.prometheus)": [[53, "graphnet.models.detector.prometheus.Prometheus"]], "feature_map() (graphnet.models.detector.prometheus.prometheus method)": [[53, "graphnet.models.detector.prometheus.Prometheus.feature_map"]], "graphnet.models.detector.prometheus": [[53, "module-graphnet.models.detector.prometheus"]], "graphnet.models.gnn": [[54, "module-graphnet.models.gnn"]], "convnet (class in graphnet.models.gnn.convnet)": [[55, "graphnet.models.gnn.convnet.ConvNet"]], "forward() (graphnet.models.gnn.convnet.convnet method)": [[55, "graphnet.models.gnn.convnet.ConvNet.forward"]], "graphnet.models.gnn.convnet": [[55, "module-graphnet.models.gnn.convnet"]], "dynedge (class in graphnet.models.gnn.dynedge)": [[56, "graphnet.models.gnn.dynedge.DynEdge"]], "forward() (graphnet.models.gnn.dynedge.dynedge method)": [[56, "graphnet.models.gnn.dynedge.DynEdge.forward"]], "graphnet.models.gnn.dynedge": [[56, "module-graphnet.models.gnn.dynedge"]], "dynedgejinst (class in graphnet.models.gnn.dynedge_jinst)": [[57, "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST"]], "forward() (graphnet.models.gnn.dynedge_jinst.dynedgejinst method)": [[57, "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward"]], "graphnet.models.gnn.dynedge_jinst": [[57, "module-graphnet.models.gnn.dynedge_jinst"]], "dynedgetito (class in graphnet.models.gnn.dynedge_kaggle_tito)": [[58, "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO"]], "forward() (graphnet.models.gnn.dynedge_kaggle_tito.dynedgetito method)": [[58, "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward"]], "graphnet.models.gnn.dynedge_kaggle_tito": [[58, "module-graphnet.models.gnn.dynedge_kaggle_tito"]], "gnn (class in graphnet.models.gnn.gnn)": [[59, "graphnet.models.gnn.gnn.GNN"]], "forward() (graphnet.models.gnn.gnn.gnn method)": [[59, "graphnet.models.gnn.gnn.GNN.forward"]], "graphnet.models.gnn.gnn": [[59, "module-graphnet.models.gnn.gnn"]], "nb_inputs (graphnet.models.gnn.gnn.gnn property)": [[59, "graphnet.models.gnn.gnn.GNN.nb_inputs"]], "nb_outputs (graphnet.models.gnn.gnn.gnn property)": [[59, "graphnet.models.gnn.gnn.GNN.nb_outputs"]], "graphnet.models.graphs": [[60, "module-graphnet.models.graphs"]], "graphnet.models.graphs.edges": [[61, "module-graphnet.models.graphs.edges"]], "edgedefinition (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.EdgeDefinition"]], "euclideanedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.EuclideanEdges"]], "knnedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.KNNEdges"]], "radialedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.RadialEdges"]], "forward() (graphnet.models.graphs.edges.edges.edgedefinition method)": [[62, "graphnet.models.graphs.edges.edges.EdgeDefinition.forward"]], "graphnet.models.graphs.edges.edges": [[62, "module-graphnet.models.graphs.edges.edges"]], "graphdefinition (class in graphnet.models.graphs.graph_definition)": [[63, "graphnet.models.graphs.graph_definition.GraphDefinition"]], "forward() (graphnet.models.graphs.graph_definition.graphdefinition method)": [[63, "graphnet.models.graphs.graph_definition.GraphDefinition.forward"]], "graphnet.models.graphs.graph_definition": [[63, "module-graphnet.models.graphs.graph_definition"]], "knngraph (class in graphnet.models.graphs.graphs)": [[64, "graphnet.models.graphs.graphs.KNNGraph"]], "graphnet.models.graphs.graphs": [[64, "module-graphnet.models.graphs.graphs"]], "graphnet.models.graphs.nodes": [[65, "module-graphnet.models.graphs.nodes"]], "nodedefinition (class in graphnet.models.graphs.nodes.nodes)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition"]], "nodesaspulses (class in graphnet.models.graphs.nodes.nodes)": [[66, "graphnet.models.graphs.nodes.nodes.NodesAsPulses"]], "forward() (graphnet.models.graphs.nodes.nodes.nodedefinition method)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.forward"]], "graphnet.models.graphs.nodes.nodes": [[66, "module-graphnet.models.graphs.nodes.nodes"]], "nb_outputs (graphnet.models.graphs.nodes.nodes.nodedefinition property)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs"]], "set_number_of_inputs() (graphnet.models.graphs.nodes.nodes.nodedefinition method)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs"]], "model (class in graphnet.models.model)": [[67, "graphnet.models.model.Model"]], "fit() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.fit"]], "forward() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.forward"]], "from_config() (graphnet.models.model.model class method)": [[67, "graphnet.models.model.Model.from_config"]], "graphnet.models.model": [[67, "module-graphnet.models.model"]], "load() (graphnet.models.model.model class method)": [[67, "graphnet.models.model.Model.load"]], "load_state_dict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.load_state_dict"]], "predict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.predict"]], "predict_as_dataframe() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.predict_as_dataframe"]], "save() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.save"]], "save_state_dict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.save_state_dict"]], "standardmodel (class in graphnet.models.standard_model)": [[68, "graphnet.models.standard_model.StandardModel"]], "compute_loss() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.compute_loss"]], "configure_optimizers() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.configure_optimizers"]], "forward() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.forward"]], "graphnet.models.standard_model": [[68, "module-graphnet.models.standard_model"]], "inference() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.inference"]], "predict() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.predict"]], "predict_as_dataframe() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.predict_as_dataframe"]], "prediction_labels (graphnet.models.standard_model.standardmodel property)": [[68, "graphnet.models.standard_model.StandardModel.prediction_labels"]], "shared_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.shared_step"]], "target_labels (graphnet.models.standard_model.standardmodel property)": [[68, "graphnet.models.standard_model.StandardModel.target_labels"]], "train() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.train"]], "training_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.training_step"]], "validation_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.validation_step"]], "graphnet.models.task": [[69, "module-graphnet.models.task"]], "binaryclassificationtask (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.BinaryClassificationTask"]], "binaryclassificationtasklogits (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits"]], "multiclassclassificationtask (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.MulticlassClassificationTask"]], "default_prediction_labels (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels"]], "default_target_labels (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.default_target_labels"]], "default_target_labels (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels"]], "graphnet.models.task.classification": [[70, "module-graphnet.models.task.classification"]], "nb_inputs (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.nb_inputs"]], "nb_inputs (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs"]], "azimuthreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction"]], "azimuthreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa"]], "directionreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa"]], "energyreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction"]], "energyreconstructionwithpower (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower"]], "energyreconstructionwithuncertainty (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty"]], "inelasticityreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction"]], "positionreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction"]], "timereconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction"]], "vertexreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction"]], "zenithreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction"]], "zenithreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa"]], "default_prediction_labels (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels"]], "default_target_labels (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels"]], "graphnet.models.task.reconstruction": [[71, "module-graphnet.models.task.reconstruction"]], "nb_inputs (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs"]], "identitytask (class in graphnet.models.task.task)": [[72, "graphnet.models.task.task.IdentityTask"]], "task (class in graphnet.models.task.task)": [[72, "graphnet.models.task.task.Task"]], "compute_loss() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.compute_loss"]], "default_prediction_labels (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.default_prediction_labels"]], "default_target_labels (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.default_target_labels"]], "default_target_labels (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.default_target_labels"]], "forward() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.forward"]], "graphnet.models.task.task": [[72, "module-graphnet.models.task.task"]], "inference() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.inference"]], "nb_inputs (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.nb_inputs"]], "nb_inputs (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.nb_inputs"]], "train_eval() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.train_eval"]], "calculate_distance_matrix() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.calculate_distance_matrix"]], "calculate_xyzt_homophily() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.calculate_xyzt_homophily"]], "graphnet.models.utils": [[73, "module-graphnet.models.utils"]], "knn_graph_batch() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.knn_graph_batch"]], "graphnet.pisa": [[74, "module-graphnet.pisa"]], "contourfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.ContourFitter"]], "weightfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.WeightFitter"]], "config_updater() (in module graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.config_updater"]], "fit_1d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_1d_contour"]], "fit_2d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_2d_contour"]], "fit_weights() (graphnet.pisa.fitting.weightfitter method)": [[75, "graphnet.pisa.fitting.WeightFitter.fit_weights"]], "graphnet.pisa.fitting": [[75, "module-graphnet.pisa.fitting"]], "graphnet.pisa.plotting": [[76, "module-graphnet.pisa.plotting"]], "plot_1d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_1D_contour"]], "plot_2d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_2D_contour"]], "read_entry() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.read_entry"]], "graphnet.training": [[77, "module-graphnet.training"]], "piecewiselinearlr (class in graphnet.training.callbacks)": [[78, "graphnet.training.callbacks.PiecewiseLinearLR"]], "progressbar (class in graphnet.training.callbacks)": [[78, "graphnet.training.callbacks.ProgressBar"]], "get_lr() (graphnet.training.callbacks.piecewiselinearlr method)": [[78, "graphnet.training.callbacks.PiecewiseLinearLR.get_lr"]], "get_metrics() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.get_metrics"]], "graphnet.training.callbacks": [[78, "module-graphnet.training.callbacks"]], "init_predict_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_predict_tqdm"]], "init_test_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_test_tqdm"]], "init_train_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_train_tqdm"]], "init_validation_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_validation_tqdm"]], "on_train_epoch_end() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.on_train_epoch_end"]], "on_train_epoch_start() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.on_train_epoch_start"]], "direction (class in graphnet.training.labels)": [[79, "graphnet.training.labels.Direction"]], "label (class in graphnet.training.labels)": [[79, "graphnet.training.labels.Label"]], "graphnet.training.labels": [[79, "module-graphnet.training.labels"]], "key (graphnet.training.labels.label property)": [[79, "graphnet.training.labels.Label.key"]], "binarycrossentropyloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.BinaryCrossEntropyLoss"]], "crossentropyloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.CrossEntropyLoss"]], "euclideandistanceloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.EuclideanDistanceLoss"]], "logcmk (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LogCMK"]], "logcoshloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LogCoshLoss"]], "lossfunction (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LossFunction"]], "mseloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.MSELoss"]], "rmseloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.RMSELoss"]], "vonmisesfisher2dloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisher2DLoss"]], "vonmisesfisher3dloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisher3DLoss"]], "vonmisesfisherloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss"]], "backward() (graphnet.training.loss_functions.logcmk static method)": [[80, "graphnet.training.loss_functions.LogCMK.backward"]], "forward() (graphnet.training.loss_functions.logcmk static method)": [[80, "graphnet.training.loss_functions.LogCMK.forward"]], "forward() (graphnet.training.loss_functions.lossfunction method)": [[80, "graphnet.training.loss_functions.LossFunction.forward"]], "graphnet.training.loss_functions": [[80, "module-graphnet.training.loss_functions"]], "log_cmk() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk"]], "log_cmk_approx() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx"]], "log_cmk_exact() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact"]], "collate_fn() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.collate_fn"]], "get_predictions() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.get_predictions"]], "graphnet.training.utils": [[81, "module-graphnet.training.utils"]], "make_dataloader() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.make_dataloader"]], "make_train_validation_dataloader() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.make_train_validation_dataloader"]], "save_results() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.save_results"]], "bjoernlow (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.BjoernLow"]], "uniform (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.Uniform"]], "weightfitter (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.WeightFitter"]], "fit() (graphnet.training.weight_fitting.weightfitter method)": [[82, "graphnet.training.weight_fitting.WeightFitter.fit"]], "graphnet.training.weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "graphnet.utilities": [[83, "module-graphnet.utilities"]], "argumentparser (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.ArgumentParser"]], "options (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.Options"]], "contains() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.contains"]], "graphnet.utilities.argparse": [[84, "module-graphnet.utilities.argparse"]], "pop_default() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.pop_default"]], "standard_arguments (graphnet.utilities.argparse.argumentparser attribute)": [[84, "graphnet.utilities.argparse.ArgumentParser.standard_arguments"]], "with_standard_arguments() (graphnet.utilities.argparse.argumentparser method)": [[84, "graphnet.utilities.argparse.ArgumentParser.with_standard_arguments"]], "graphnet.utilities.config": [[85, "module-graphnet.utilities.config"]], "baseconfig (class in graphnet.utilities.config.base_config)": [[86, "graphnet.utilities.config.base_config.BaseConfig"]], "as_dict() (graphnet.utilities.config.base_config.baseconfig method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.as_dict"]], "dump() (graphnet.utilities.config.base_config.baseconfig method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.dump"]], "get_all_argument_values() (in module graphnet.utilities.config.base_config)": [[86, "graphnet.utilities.config.base_config.get_all_argument_values"]], "graphnet.utilities.config.base_config": [[86, "module-graphnet.utilities.config.base_config"]], "load() (graphnet.utilities.config.base_config.baseconfig class method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.load"]], "model_config (graphnet.utilities.config.base_config.baseconfig attribute)": [[86, "graphnet.utilities.config.base_config.BaseConfig.model_config"]], "model_fields (graphnet.utilities.config.base_config.baseconfig attribute)": [[86, "graphnet.utilities.config.base_config.BaseConfig.model_fields"]], "configurable (class in graphnet.utilities.config.configurable)": [[87, "graphnet.utilities.config.configurable.Configurable"]], "config (graphnet.utilities.config.configurable.configurable property)": [[87, "graphnet.utilities.config.configurable.Configurable.config"]], "from_config() (graphnet.utilities.config.configurable.configurable class method)": [[87, "graphnet.utilities.config.configurable.Configurable.from_config"]], "graphnet.utilities.config.configurable": [[87, "module-graphnet.utilities.config.configurable"]], "save_config() (graphnet.utilities.config.configurable.configurable method)": [[87, "graphnet.utilities.config.configurable.Configurable.save_config"]], "datasetconfig (class in graphnet.utilities.config.dataset_config)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig"]], "as_dict() (graphnet.utilities.config.dataset_config.datasetconfig method)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.as_dict"]], "features (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.features"]], "graph_definition (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition"]], "graphnet.utilities.config.dataset_config": [[88, "module-graphnet.utilities.config.dataset_config"]], "index_column (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.index_column"]], "loss_weight_column (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column"]], "loss_weight_default_value (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value"]], "loss_weight_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table"]], "model_config (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.model_config"]], "model_fields (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.model_fields"]], "node_truth (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.node_truth"]], "node_truth_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table"]], "path (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.path"]], "pulsemaps (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps"]], "save_dataset_config() (in module graphnet.utilities.config.dataset_config)": [[88, "graphnet.utilities.config.dataset_config.save_dataset_config"]], "seed (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.seed"]], "selection (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.selection"]], "string_selection (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.string_selection"]], "truth (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.truth"]], "truth_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.truth_table"]], "modelconfig (class in graphnet.utilities.config.model_config)": [[89, "graphnet.utilities.config.model_config.ModelConfig"]], "arguments (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.arguments"]], "as_dict() (graphnet.utilities.config.model_config.modelconfig method)": [[89, "graphnet.utilities.config.model_config.ModelConfig.as_dict"]], "class_name (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.class_name"]], "graphnet.utilities.config.model_config": [[89, "module-graphnet.utilities.config.model_config"]], "model_config (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.model_config"]], "model_fields (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.model_fields"]], "save_model_config() (in module graphnet.utilities.config.model_config)": [[89, "graphnet.utilities.config.model_config.save_model_config"]], "get_all_grapnet_classes() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.get_all_grapnet_classes"]], "get_graphnet_classes() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.get_graphnet_classes"]], "graphnet.utilities.config.parsing": [[90, "module-graphnet.utilities.config.parsing"]], "is_graphnet_class() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.is_graphnet_class"]], "is_graphnet_module() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.is_graphnet_module"]], "list_all_submodules() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.list_all_submodules"]], "traverse_and_apply() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.traverse_and_apply"]], "trainingconfig (class in graphnet.utilities.config.training_config)": [[91, "graphnet.utilities.config.training_config.TrainingConfig"]], "dataloader (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.dataloader"]], "early_stopping_patience (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience"]], "fit (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.fit"]], "graphnet.utilities.config.training_config": [[91, "module-graphnet.utilities.config.training_config"]], "model_config (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.model_config"]], "model_fields (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.model_fields"]], "target (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.target"]], "graphnet.utilities.decorators": [[92, "module-graphnet.utilities.decorators"]], "find_i3_files() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.find_i3_files"]], "graphnet.utilities.filesys": [[93, "module-graphnet.utilities.filesys"]], "has_extension() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.has_extension"]], "is_gcd_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_gcd_file"]], "is_i3_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_i3_file"]], "graphnet.utilities.imports": [[94, "module-graphnet.utilities.imports"]], "has_icecube_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_icecube_package"]], "has_pisa_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_pisa_package"]], "has_torch_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_torch_package"]], "requires_icecube() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.requires_icecube"]], "logger (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.Logger"]], "repeatfilter (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.RepeatFilter"]], "critical() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.critical"]], "debug() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.debug"]], "error() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.error"]], "file_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.file_handlers"]], "filter() (graphnet.utilities.logging.repeatfilter method)": [[95, "graphnet.utilities.logging.RepeatFilter.filter"]], "graphnet.utilities.logging": [[95, "module-graphnet.utilities.logging"]], "handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.handlers"]], "info() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.info"]], "nb_repeats_allowed (graphnet.utilities.logging.repeatfilter attribute)": [[95, "graphnet.utilities.logging.RepeatFilter.nb_repeats_allowed"]], "setlevel() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.setLevel"]], "stream_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.stream_handlers"]], "warning() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning"]], "warning_once() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning_once"]], "eps_like() (in module graphnet.utilities.maths)": [[96, "graphnet.utilities.maths.eps_like"]], "graphnet.utilities.maths": [[96, "module-graphnet.utilities.maths"]]}})
\ No newline at end of file
diff --git a/sitemap.xml b/sitemap.xml
index 8ca584144..1332adef5 100644
--- a/sitemap.xml
+++ b/sitemap.xml
@@ -1 +1 @@
-<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"><url><loc>https://graphnet-team.github.io/graphnetabout.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataloader.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.parquet_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3extractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3featureextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3genericextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3hybridrecoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3ntmuonlabelsextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3particleextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3pisaextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3quesoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3retroextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3splinempeextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3truthextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3tumextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.collections.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.frames.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.types.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.parquet_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.pipeline.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.parquet_to_sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.random.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.string_selection_resolver.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.deployer.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.graphnet_module.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.coarsening.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.layers.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.pool.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.icecube.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.prometheus.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.convnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_jinst.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_kaggle_tito.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graph_definition.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.standard_model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.classification.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.reconstruction.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.plotting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.callbacks.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.labels.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.loss_functions.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.weight_fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.argparse.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.base_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.configurable.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.dataset_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.model_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.parsing.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.training_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.decorators.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.filesys.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.imports.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.logging.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.maths.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/modules.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetcontribute.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetinstall.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetgenindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetpy-modindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3extractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3featureextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3genericextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3hybridrecoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3particleextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3pisaextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3quesoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3retroextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3splinempeextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3truthextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3tumextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/collections.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/frames.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/types.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/parquet/parquet_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/parquet_to_sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/random.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/string_selection_resolver.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/pisa/fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/pisa/plotting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/weight_fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/argparse.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/filesys.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/imports.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/logging.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/index.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetsearch.html</loc></url></urlset>
\ No newline at end of file
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"><url><loc>https://graphnet-team.github.io/graphnetabout.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataloader.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.parquet_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3extractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3featureextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3genericextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3hybridrecoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3ntmuonlabelsextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3particleextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3pisaextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3quesoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3retroextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3splinempeextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3truthextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3tumextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.collections.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.frames.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.types.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.parquet_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.pipeline.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.parquet_to_sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.random.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.string_selection_resolver.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.deployer.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.graphnet_module.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.coarsening.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.layers.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.pool.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.icecube.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.prometheus.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.convnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_jinst.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_kaggle_tito.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graph_definition.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.standard_model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.classification.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.reconstruction.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.plotting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.callbacks.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.labels.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.loss_functions.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.weight_fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.argparse.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.base_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.configurable.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.dataset_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.model_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.parsing.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.training_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.decorators.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.filesys.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.imports.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.logging.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.maths.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/modules.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetcontribute.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetinstall.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetgenindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetpy-modindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataloader.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/parquet/parquet_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3extractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3featureextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3genericextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3hybridrecoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3particleextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3pisaextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3quesoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3retroextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3splinempeextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3truthextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3tumextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/collections.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/frames.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/types.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/parquet/parquet_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/pipeline.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/parquet_to_sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/random.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/string_selection_resolver.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/deployment/i3modules/graphnet_module.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/coarsening.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/components/layers.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/components/pool.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/icecube.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/prometheus.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/convnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge_jinst.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge_kaggle_tito.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/edges/edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/graph_definition.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/nodes/nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/standard_model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/task/classification.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/task/reconstruction.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/task/task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/pisa/fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/pisa/plotting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/callbacks.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/labels.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/loss_functions.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/weight_fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/argparse.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/base_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/configurable.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/dataset_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/model_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/parsing.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/training_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/filesys.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/imports.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/logging.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/maths.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/index.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetsearch.html</loc></url></urlset>
\ No newline at end of file