From 64f5b9be67cf67eaabeb4283a018b7df48f242a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= <andreas.sogaard@gmail.com> Date: Thu, 14 Sep 2023 09:02:04 +0000 Subject: [PATCH] =?UTF-8?q?Deploying=20to=20gh-pages=20from=20@=20graphnet?= =?UTF-8?q?-team/graphnet@6edf835e0f6c1e044d853c5d8b74a5c5ea50f99d=20?= =?UTF-8?q?=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- _modules/graphnet/data/constants.html | 2 +- _modules/graphnet/data/dataconverter.html | 2 +- _modules/graphnet/data/dataloader.html | 457 ++++++ _modules/graphnet/data/dataset/dataset.html | 1084 ++++++++++++++ .../data/dataset/parquet/parquet_dataset.html | 500 +++++++ .../data/dataset/sqlite/sqlite_dataset.html | 515 +++++++ .../sqlite/sqlite_dataset_perturbed.html | 515 +++++++ .../graphnet/data/extractors/i3extractor.html | 2 +- .../data/extractors/i3featureextractor.html | 2 +- .../data/extractors/i3genericextractor.html | 2 +- .../extractors/i3hybridrecoextractor.html | 2 +- .../extractors/i3ntmuonlabelsextractor.html | 2 +- .../data/extractors/i3particleextractor.html | 2 +- .../data/extractors/i3pisaextractor.html | 2 +- .../data/extractors/i3quesoextractor.html | 2 +- .../data/extractors/i3retroextractor.html | 2 +- .../data/extractors/i3splinempeextractor.html | 2 +- .../data/extractors/i3truthextractor.html | 2 +- .../data/extractors/i3tumextractor.html | 2 +- .../extractors/utilities/collections.html | 2 +- .../data/extractors/utilities/frames.html | 2 +- .../data/extractors/utilities/types.html | 2 +- .../data/parquet/parquet_dataconverter.html | 2 +- _modules/graphnet/data/pipeline.html | 593 ++++++++ .../data/sqlite/sqlite_dataconverter.html | 2 +- .../data/sqlite/sqlite_utilities.html | 2 +- .../data/utilities/parquet_to_sqlite.html | 2 +- _modules/graphnet/data/utilities/random.html | 2 +- .../utilities/string_selection_resolver.html | 2 +- .../deployment/i3modules/graphnet_module.html | 817 +++++++++++ _modules/graphnet/models/coarsening.html | 711 +++++++++ .../graphnet/models/components/layers.html | 579 ++++++++ _modules/graphnet/models/components/pool.html | 656 +++++++++ .../graphnet/models/detector/detector.html | 421 ++++++ .../graphnet/models/detector/icecube.html | 528 +++++++ .../graphnet/models/detector/prometheus.html | 395 +++++ _modules/graphnet/models/gnn/convnet.html | 486 +++++++ _modules/graphnet/models/gnn/dynedge.html | 693 +++++++++ .../graphnet/models/gnn/dynedge_jinst.html | 521 +++++++ .../models/gnn/dynedge_kaggle_tito.html | 618 ++++++++ _modules/graphnet/models/gnn/gnn.html | 403 +++++ .../graphnet/models/graphs/edges/edges.html | 563 +++++++ .../models/graphs/graph_definition.html | 634 ++++++++ _modules/graphnet/models/graphs/graphs.html | 410 ++++++ .../graphnet/models/graphs/nodes/nodes.html | 444 ++++++ _modules/graphnet/models/model.html | 726 ++++++++++ _modules/graphnet/models/standard_model.html | 604 ++++++++ .../graphnet/models/task/classification.html | 411 ++++++ .../graphnet/models/task/reconstruction.html | 609 ++++++++ _modules/graphnet/models/task/task.html | 688 +++++++++ _modules/graphnet/models/utils.html | 430 ++++++ _modules/graphnet/pisa/fitting.html | 2 +- _modules/graphnet/pisa/plotting.html | 2 +- _modules/graphnet/training/callbacks.html | 544 +++++++ _modules/graphnet/training/labels.html | 436 ++++++ .../graphnet/training/loss_functions.html | 859 +++++++++++ _modules/graphnet/training/utils.html | 656 +++++++++ .../graphnet/training/weight_fitting.html | 2 +- _modules/graphnet/utilities/argparse.html | 2 +- .../utilities/config/base_config.html | 449 ++++++ .../utilities/config/configurable.html | 408 ++++++ .../utilities/config/dataset_config.html | 585 ++++++++ .../utilities/config/model_config.html | 654 +++++++++ .../graphnet/utilities/config/parsing.html | 475 ++++++ .../utilities/config/training_config.html | 378 +++++ _modules/graphnet/utilities/filesys.html | 2 +- _modules/graphnet/utilities/imports.html | 2 +- _modules/graphnet/utilities/logging.html | 2 +- _modules/graphnet/utilities/maths.html | 371 +++++ _modules/index.html | 41 +- about.html | 2 +- api/graphnet.constants.html | 2 +- api/graphnet.data.constants.html | 2 +- api/graphnet.data.dataconverter.html | 2 +- api/graphnet.data.dataloader.html | 129 +- api/graphnet.data.dataset.dataset.html | 330 ++++- api/graphnet.data.dataset.html | 16 +- api/graphnet.data.dataset.parquet.html | 12 +- ....data.dataset.parquet.parquet_dataset.html | 116 +- api/graphnet.data.dataset.sqlite.html | 17 +- ...et.data.dataset.sqlite.sqlite_dataset.html | 116 +- ...taset.sqlite.sqlite_dataset_perturbed.html | 84 +- api/graphnet.data.extractors.html | 2 +- api/graphnet.data.extractors.i3extractor.html | 2 +- ...et.data.extractors.i3featureextractor.html | 2 +- ...et.data.extractors.i3genericextractor.html | 2 +- ...data.extractors.i3hybridrecoextractor.html | 2 +- ...ta.extractors.i3ntmuonlabelsextractor.html | 2 +- ...t.data.extractors.i3particleextractor.html | 2 +- ...phnet.data.extractors.i3pisaextractor.html | 2 +- ...hnet.data.extractors.i3quesoextractor.html | 2 +- ...hnet.data.extractors.i3retroextractor.html | 2 +- ....data.extractors.i3splinempeextractor.html | 2 +- ...hnet.data.extractors.i3truthextractor.html | 2 +- ...aphnet.data.extractors.i3tumextractor.html | 2 +- ...data.extractors.utilities.collections.html | 2 +- ...hnet.data.extractors.utilities.frames.html | 2 +- api/graphnet.data.extractors.utilities.html | 2 +- ...phnet.data.extractors.utilities.types.html | 2 +- api/graphnet.data.html | 14 +- api/graphnet.data.parquet.html | 2 +- ...et.data.parquet.parquet_dataconverter.html | 2 +- api/graphnet.data.pipeline.html | 55 +- api/graphnet.data.sqlite.html | 2 +- ...hnet.data.sqlite.sqlite_dataconverter.html | 2 +- ...graphnet.data.sqlite.sqlite_utilities.html | 2 +- api/graphnet.data.utilities.html | 2 +- ...hnet.data.utilities.parquet_to_sqlite.html | 2 +- api/graphnet.data.utilities.random.html | 2 +- ...a.utilities.string_selection_resolver.html | 4 +- api/graphnet.deployment.html | 2 +- ...raphnet.deployment.i3modules.deployer.html | 2 +- ....deployment.i3modules.graphnet_module.html | 135 +- api/graphnet.deployment.i3modules.html | 9 +- api/graphnet.html | 2 +- api/graphnet.models.coarsening.html | 226 ++- api/graphnet.models.components.html | 28 +- api/graphnet.models.components.layers.html | 253 +++- api/graphnet.models.components.pool.html | 339 ++++- api/graphnet.models.detector.detector.html | 89 +- api/graphnet.models.detector.html | 25 +- api/graphnet.models.detector.icecube.html | 197 ++- api/graphnet.models.detector.prometheus.html | 62 +- api/graphnet.models.gnn.convnet.html | 76 +- api/graphnet.models.gnn.dynedge.html | 98 +- api/graphnet.models.gnn.dynedge_jinst.html | 73 +- ...aphnet.models.gnn.dynedge_kaggle_tito.html | 83 +- api/graphnet.models.gnn.gnn.html | 103 +- api/graphnet.models.gnn.html | 32 +- api/graphnet.models.graphs.edges.edges.html | 169 ++- api/graphnet.models.graphs.edges.html | 18 +- ...aphnet.models.graphs.graph_definition.html | 93 +- api/graphnet.models.graphs.graphs.html | 50 +- api/graphnet.models.graphs.html | 20 +- api/graphnet.models.graphs.nodes.html | 16 +- api/graphnet.models.graphs.nodes.nodes.html | 132 +- api/graphnet.models.html | 39 +- api/graphnet.models.model.html | 305 +++- api/graphnet.models.standard_model.html | 348 ++++- api/graphnet.models.task.classification.html | 262 +++- api/graphnet.models.task.html | 36 +- api/graphnet.models.task.reconstruction.html | 1290 ++++++++++++++++- api/graphnet.models.task.task.html | 302 +++- api/graphnet.models.utils.html | 110 +- api/graphnet.pisa.fitting.html | 2 +- api/graphnet.pisa.html | 2 +- api/graphnet.pisa.plotting.html | 2 +- api/graphnet.training.callbacks.html | 277 +++- api/graphnet.training.html | 38 +- api/graphnet.training.labels.html | 93 +- api/graphnet.training.loss_functions.html | 488 ++++++- api/graphnet.training.utils.html | 191 ++- api/graphnet.training.weight_fitting.html | 2 +- api/graphnet.utilities.argparse.html | 2 +- ...graphnet.utilities.config.base_config.html | 177 ++- ...raphnet.utilities.config.configurable.html | 105 +- ...phnet.utilities.config.dataset_config.html | 438 +++++- api/graphnet.utilities.config.html | 45 +- ...raphnet.utilities.config.model_config.html | 177 ++- api/graphnet.utilities.config.parsing.html | 165 ++- ...hnet.utilities.config.training_config.html | 147 +- api/graphnet.utilities.decorators.html | 2 +- api/graphnet.utilities.filesys.html | 2 +- api/graphnet.utilities.html | 7 +- api/graphnet.utilities.imports.html | 2 +- api/graphnet.utilities.logging.html | 2 +- api/graphnet.utilities.maths.html | 41 +- api/modules.html | 2 +- contribute.html | 2 +- genindex.html | 1174 ++++++++++++++- index.html | 2 +- install.html | 2 +- objects.inv | Bin 3520 -> 6210 bytes py-modindex.html | 257 +++- search.html | 2 +- searchindex.js | 2 +- sitemap.xml | 2 +- 177 files changed, 31425 insertions(+), 329 deletions(-) create mode 100644 _modules/graphnet/data/dataloader.html create mode 100644 _modules/graphnet/data/dataset/dataset.html create mode 100644 _modules/graphnet/data/dataset/parquet/parquet_dataset.html create mode 100644 _modules/graphnet/data/dataset/sqlite/sqlite_dataset.html create mode 100644 _modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html create mode 100644 _modules/graphnet/data/pipeline.html create mode 100644 _modules/graphnet/deployment/i3modules/graphnet_module.html create mode 100644 _modules/graphnet/models/coarsening.html create mode 100644 _modules/graphnet/models/components/layers.html create mode 100644 _modules/graphnet/models/components/pool.html create mode 100644 _modules/graphnet/models/detector/detector.html create mode 100644 _modules/graphnet/models/detector/icecube.html create mode 100644 _modules/graphnet/models/detector/prometheus.html create mode 100644 _modules/graphnet/models/gnn/convnet.html create mode 100644 _modules/graphnet/models/gnn/dynedge.html create mode 100644 _modules/graphnet/models/gnn/dynedge_jinst.html create mode 100644 _modules/graphnet/models/gnn/dynedge_kaggle_tito.html create mode 100644 _modules/graphnet/models/gnn/gnn.html create mode 100644 _modules/graphnet/models/graphs/edges/edges.html create mode 100644 _modules/graphnet/models/graphs/graph_definition.html create mode 100644 _modules/graphnet/models/graphs/graphs.html create mode 100644 _modules/graphnet/models/graphs/nodes/nodes.html create mode 100644 _modules/graphnet/models/model.html create mode 100644 _modules/graphnet/models/standard_model.html create mode 100644 _modules/graphnet/models/task/classification.html create mode 100644 _modules/graphnet/models/task/reconstruction.html create mode 100644 _modules/graphnet/models/task/task.html create mode 100644 _modules/graphnet/models/utils.html create mode 100644 _modules/graphnet/training/callbacks.html create mode 100644 _modules/graphnet/training/labels.html create mode 100644 _modules/graphnet/training/loss_functions.html create mode 100644 _modules/graphnet/training/utils.html create mode 100644 _modules/graphnet/utilities/config/base_config.html create mode 100644 _modules/graphnet/utilities/config/configurable.html create mode 100644 _modules/graphnet/utilities/config/dataset_config.html create mode 100644 _modules/graphnet/utilities/config/model_config.html create mode 100644 _modules/graphnet/utilities/config/parsing.html create mode 100644 _modules/graphnet/utilities/config/training_config.html create mode 100644 _modules/graphnet/utilities/maths.html diff --git a/_modules/graphnet/data/constants.html b/_modules/graphnet/data/constants.html index cd439b37c..b6526af39 100644 --- a/_modules/graphnet/data/constants.html +++ b/_modules/graphnet/data/constants.html @@ -436,7 +436,7 @@ <h1 id="modules-graphnet-data-constants--page-root">Source code for graphnet.dat </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/dataconverter.html b/_modules/graphnet/data/dataconverter.html index aa78de84b..9be52a7d9 100644 --- a/_modules/graphnet/data/dataconverter.html +++ b/_modules/graphnet/data/dataconverter.html @@ -938,7 +938,7 @@ <h1 id="modules-graphnet-data-dataconverter--page-root">Source code for graphnet </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/dataloader.html b/_modules/graphnet/data/dataloader.html new file mode 100644 index 000000000..4de320872 --- /dev/null +++ b/_modules/graphnet/data/dataloader.html @@ -0,0 +1,457 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.data.dataloader — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/data/dataloader" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.data.dataloader </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-data-dataloader--page-root">Source code for graphnet.data.dataloader</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Base `Dataloader` class(es) used in `graphnet`."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span> + +<span class="kn">import</span> <span class="nn">torch.utils.data</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Batch</span><span class="p">,</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">graphnet.data.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">DatasetConfig</span> + + +<div class="viewcode-block" id="collate_fn"> +<a class="viewcode-back" href="../../../api/graphnet.data.dataloader.html#graphnet.data.dataloader.collate_fn">[docs]</a> +<span class="k">def</span> <span class="nf">collate_fn</span><span class="p">(</span><span class="n">graphs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Data</span><span class="p">])</span> <span class="o">-></span> <span class="n">Batch</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Remove graphs with less than two DOM hits.</span> + +<span class="sd"> Should not occur in "production.</span> +<span class="sd"> """</span> + <span class="n">graphs</span> <span class="o">=</span> <span class="p">[</span><span class="n">g</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">graphs</span> <span class="k">if</span> <span class="n">g</span><span class="o">.</span><span class="n">n_pulses</span> <span class="o">></span> <span class="mi">1</span><span class="p">]</span> + <span class="k">return</span> <span class="n">Batch</span><span class="o">.</span><span class="n">from_data_list</span><span class="p">(</span><span class="n">graphs</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="do_shuffle"> +<a class="viewcode-back" href="../../../api/graphnet.data.dataloader.html#graphnet.data.dataloader.do_shuffle">[docs]</a> +<span class="k">def</span> <span class="nf">do_shuffle</span><span class="p">(</span><span class="n">selection_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Check whether to shuffle selection with name `selection_name`."""</span> + <span class="k">return</span> <span class="s2">"train"</span> <span class="ow">in</span> <span class="n">selection_name</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span></div> + + + +<div class="viewcode-block" id="DataLoader"> +<a class="viewcode-back" href="../../../api/graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader">[docs]</a> +<span class="k">class</span> <span class="nc">DataLoader</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Class for loading data from a `Dataset`."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">dataset</span><span class="p">:</span> <span class="n">Dataset</span><span class="p">,</span> + <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> + <span class="n">shuffle</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="n">num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> + <span class="n">collate_fn</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="n">collate_fn</span><span class="p">,</span> + <span class="n">prefetch_factor</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `DataLoader`."""</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span> + <span class="n">dataset</span><span class="p">,</span> + <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> + <span class="n">shuffle</span><span class="o">=</span><span class="n">shuffle</span><span class="p">,</span> + <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span> + <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="o">=</span><span class="n">persistent_workers</span><span class="p">,</span> + <span class="n">prefetch_factor</span><span class="o">=</span><span class="n">prefetch_factor</span><span class="p">,</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> + <span class="p">)</span> + +<div class="viewcode-block" id="DataLoader.from_dataset_config"> +<a class="viewcode-back" href="../../../api/graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader.from_dataset_config">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">from_dataset_config</span><span class="p">(</span> + <span class="bp">cls</span><span class="p">,</span> + <span class="n">config</span><span class="p">:</span> <span class="n">DatasetConfig</span><span class="p">,</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="s2">"DataLoader"</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"DataLoader"</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Construct `DataLoader`s based on selections in `DatasetConfig`."""</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span> + <span class="k">assert</span> <span class="s2">"shuffle"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">,</span> <span class="p">(</span> + <span class="s2">"When passing a `DatasetConfig` with multiple selections, "</span> + <span class="s2">"`shuffle` is automatically inferred from the selection name, "</span> + <span class="s2">"and thus should not specified as an argument."</span> + <span class="p">)</span> + <span class="n">datasets</span> <span class="o">=</span> <span class="n">Dataset</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">datasets</span><span class="p">,</span> <span class="nb">dict</span><span class="p">)</span> + <span class="n">data_loaders</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">DataLoader</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span> + <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">dataset</span> <span class="ow">in</span> <span class="n">datasets</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> + <span class="n">data_loaders</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span> + <span class="n">dataset</span><span class="p">,</span> + <span class="n">shuffle</span><span class="o">=</span><span class="n">do_shuffle</span><span class="p">(</span><span class="n">name</span><span class="p">),</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">data_loaders</span> + + <span class="k">else</span><span class="p">:</span> + <span class="k">assert</span> <span class="s2">"shuffle"</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">,</span> <span class="p">(</span> + <span class="s2">"When passing a `DatasetConfig` with a single selections, you "</span> + <span class="s2">"need to specify `shuffle` as an argument."</span> + <span class="p">)</span> + <span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">Dataset</span><span class="p">)</span> + <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/data/dataset/dataset.html b/_modules/graphnet/data/dataset/dataset.html new file mode 100644 index 000000000..f2ee86f4d --- /dev/null +++ b/_modules/graphnet/data/dataset/dataset.html @@ -0,0 +1,1084 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.data.dataset.dataset — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/data/dataset/dataset" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.data.dataset.dataset </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-data-dataset-dataset--page-root">Source code for graphnet.data.dataset.dataset</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Base :py:class:`Dataset` class(es) used in GraphNeT."""</span> + +<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span> +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">cast</span><span class="p">,</span> + <span class="n">Any</span><span class="p">,</span> + <span class="n">Callable</span><span class="p">,</span> + <span class="n">Dict</span><span class="p">,</span> + <span class="n">List</span><span class="p">,</span> + <span class="n">Optional</span><span class="p">,</span> + <span class="n">Tuple</span><span class="p">,</span> + <span class="n">Union</span><span class="p">,</span> + <span class="n">Iterable</span><span class="p">,</span> + <span class="n">Type</span><span class="p">,</span> +<span class="p">)</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">graphnet.constants</span> <span class="kn">import</span> <span class="n">GRAPHNET_ROOT_DIR</span> +<span class="kn">from</span> <span class="nn">graphnet.data.utilities.string_selection_resolver</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">StringSelectionResolver</span><span class="p">,</span> +<span class="p">)</span> +<span class="kn">from</span> <span class="nn">graphnet.training.labels</span> <span class="kn">import</span> <span class="n">Label</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">Configurable</span><span class="p">,</span> + <span class="n">DatasetConfig</span><span class="p">,</span> + <span class="n">save_dataset_config</span><span class="p">,</span> +<span class="p">)</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config.parsing</span> <span class="kn">import</span> <span class="n">traverse_and_apply</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span> +<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config.parsing</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">get_all_grapnet_classes</span><span class="p">,</span> +<span class="p">)</span> + + +<div class="viewcode-block" id="ColumnMissingException"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.ColumnMissingException">[docs]</a> +<span class="k">class</span> <span class="nc">ColumnMissingException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Exception to indicate a missing column in a dataset."""</span></div> + + + +<div class="viewcode-block" id="load_module"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.load_module">[docs]</a> +<span class="k">def</span> <span class="nf">load_module</span><span class="p">(</span><span class="n">class_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">Type</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Load graphnet module from string name.</span> + +<span class="sd"> Args:</span> +<span class="sd"> class_name: name of class</span> + +<span class="sd"> Returns:</span> +<span class="sd"> graphnet module.</span> +<span class="sd"> """</span> + <span class="c1"># Get a lookup for all classes in `graphnet`</span> + <span class="kn">import</span> <span class="nn">graphnet.data</span> + <span class="kn">import</span> <span class="nn">graphnet.models</span> + <span class="kn">import</span> <span class="nn">graphnet.training</span> + + <span class="n">namespace_classes</span> <span class="o">=</span> <span class="n">get_all_grapnet_classes</span><span class="p">(</span> + <span class="n">graphnet</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">graphnet</span><span class="o">.</span><span class="n">models</span><span class="p">,</span> <span class="n">graphnet</span><span class="o">.</span><span class="n">training</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">namespace_classes</span><span class="p">[</span><span class="n">class_name</span><span class="p">]</span></div> + + + +<div class="viewcode-block" id="parse_graph_definition"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.parse_graph_definition">[docs]</a> +<span class="k">def</span> <span class="nf">parse_graph_definition</span><span class="p">(</span><span class="n">cfg</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-></span> <span class="n">GraphDefinition</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct GraphDefinition from DatasetConfig."""</span> + <span class="k">assert</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + + <span class="n">args</span> <span class="o">=</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">][</span><span class="s2">"arguments"</span><span class="p">]</span> + <span class="n">classes</span> <span class="o">=</span> <span class="p">{}</span> + <span class="k">for</span> <span class="n">arg</span> <span class="ow">in</span> <span class="n">args</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">],</span> <span class="nb">dict</span><span class="p">):</span> + <span class="k">if</span> <span class="s2">"class_name"</span> <span class="ow">in</span> <span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">]</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">classes</span><span class="p">[</span><span class="n">arg</span><span class="p">]</span> <span class="o">=</span> <span class="n">load_module</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">][</span><span class="s2">"class_name"</span><span class="p">])(</span> + <span class="o">**</span><span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">][</span><span class="s2">"arguments"</span><span class="p">]</span> + <span class="p">)</span> + <span class="k">if</span> <span class="n">arg</span> <span class="o">==</span> <span class="s2">"dtype"</span><span class="p">:</span> + <span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">]</span> <span class="o">=</span> <span class="nb">eval</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="n">arg</span><span class="p">])</span> <span class="c1"># converts string to class</span> + + <span class="n">new_cfg</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">args</span><span class="p">)</span> + <span class="n">new_cfg</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">classes</span><span class="p">)</span> + <span class="n">graph_definition</span> <span class="o">=</span> <span class="n">load_module</span><span class="p">(</span><span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">][</span><span class="s2">"class_name"</span><span class="p">])(</span> + <span class="o">**</span><span class="n">new_cfg</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">graph_definition</span></div> + + + +<div class="viewcode-block" id="Dataset"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset">[docs]</a> +<span class="k">class</span> <span class="nc">Dataset</span><span class="p">(</span><span class="n">Logger</span><span class="p">,</span> <span class="n">Configurable</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">,</span> <span class="n">ABC</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base Dataset class for reading from any intermediate file format."""</span> + + <span class="c1"># Class method(s)</span> +<div class="viewcode-block" id="Dataset.from_config"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.from_config">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span> <span class="c1"># type: ignore[override]</span> + <span class="bp">cls</span><span class="p">,</span> + <span class="n">source</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DatasetConfig</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span> + <span class="s2">"Dataset"</span><span class="p">,</span> + <span class="s2">"EnsembleDataset"</span><span class="p">,</span> + <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"Dataset"</span><span class="p">],</span> + <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"EnsembleDataset"</span><span class="p">],</span> + <span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Construct `Dataset` instance from `source` configuration."""</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">source</span> <span class="o">=</span> <span class="n">DatasetConfig</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">source</span><span class="p">)</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="p">,</span> <span class="n">DatasetConfig</span><span class="p">),</span> <span class="p">(</span> + <span class="sa">f</span><span class="s2">"Argument `source` of type (</span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">source</span><span class="p">)</span><span class="si">}</span><span class="s2">) is not a "</span> + <span class="s2">"`DatasetConfig`"</span> + <span class="p">)</span> + + <span class="k">assert</span> <span class="p">(</span> + <span class="s2">"graph_definition"</span> <span class="ow">in</span> <span class="n">source</span><span class="o">.</span><span class="n">dict</span><span class="p">()</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> + <span class="p">),</span> <span class="s2">"`DatasetConfig` incompatible with current GraphNeT version."</span> + + <span class="c1"># Parse set of `selection``.</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span> + <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_construct_datasets_from_dict</span><span class="p">(</span><span class="n">source</span><span class="p">)</span> + <span class="k">elif</span> <span class="p">(</span> + <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">source</span><span class="o">.</span><span class="n">selection</span><span class="p">)</span> + <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="o">.</span><span class="n">selection</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">str</span><span class="p">)</span> + <span class="p">):</span> + <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_construct_dataset_from_list_of_strings</span><span class="p">(</span><span class="n">source</span><span class="p">)</span> + + <span class="n">cfg</span> <span class="o">=</span> <span class="n">source</span><span class="o">.</span><span class="n">dict</span><span class="p">()</span> + <span class="k">if</span> <span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">cfg</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">]</span> <span class="o">=</span> <span class="n">parse_graph_definition</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span> + <span class="k">return</span> <span class="n">source</span><span class="o">.</span><span class="n">_dataset_class</span><span class="p">(</span><span class="o">**</span><span class="n">cfg</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="Dataset.concatenate"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.concatenate">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">concatenate</span><span class="p">(</span> + <span class="bp">cls</span><span class="p">,</span> + <span class="n">datasets</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">"Dataset"</span><span class="p">],</span> + <span class="p">)</span> <span class="o">-></span> <span class="s2">"EnsembleDataset"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Concatenate multiple `Dataset`s into one instance."""</span> + <span class="k">return</span> <span class="n">EnsembleDataset</span><span class="p">(</span><span class="n">datasets</span><span class="p">)</span></div> + + + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">_construct_datasets_from_dict</span><span class="p">(</span> + <span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">DatasetConfig</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"Dataset"</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Construct `Dataset` for each entry in dict `self.selection`."""</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">dict</span><span class="p">)</span> + <span class="n">datasets</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"Dataset"</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span> + <span class="n">selections</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">]]</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">)</span> + <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">selection</span> <span class="ow">in</span> <span class="n">selections</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> + <span class="n">config</span><span class="o">.</span><span class="n">selection</span> <span class="o">=</span> <span class="n">selection</span> + <span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="p">(</span><span class="n">Dataset</span><span class="p">,</span> <span class="n">EnsembleDataset</span><span class="p">))</span> + <span class="n">datasets</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span> + + <span class="c1"># Reset `selections`.</span> + <span class="n">config</span><span class="o">.</span><span class="n">selection</span> <span class="o">=</span> <span class="n">selections</span> + + <span class="k">return</span> <span class="n">datasets</span> + + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">_construct_dataset_from_list_of_strings</span><span class="p">(</span> + <span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">DatasetConfig</span> + <span class="p">)</span> <span class="o">-></span> <span class="s2">"Dataset"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `Dataset` for each entry in list `self.selection`."""</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="n">datasets</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">"Dataset"</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> + <span class="n">selections</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">config</span><span class="o">.</span><span class="n">selection</span><span class="p">))</span> + <span class="k">for</span> <span class="n">selection</span> <span class="ow">in</span> <span class="n">selections</span><span class="p">:</span> + <span class="n">config</span><span class="o">.</span><span class="n">selection</span> <span class="o">=</span> <span class="n">selection</span> + <span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">Dataset</span><span class="p">)</span> + <span class="n">datasets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> + + <span class="c1"># Reset `selections`.</span> + <span class="n">config</span><span class="o">.</span><span class="n">selection</span> <span class="o">=</span> <span class="n">selections</span> + + <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">datasets</span><span class="p">)</span> + + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">_resolve_graphnet_paths</span><span class="p">(</span> + <span class="bp">cls</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]:</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="k">return</span> <span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_resolve_graphnet_paths</span><span class="p">(</span><span class="n">p</span><span class="p">))</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">path</span><span class="p">]</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> + <span class="k">return</span> <span class="p">(</span> + <span class="n">path</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"$graphnet"</span><span class="p">,</span> <span class="n">GRAPHNET_ROOT_DIR</span><span class="p">)</span> + <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"$GRAPHNET"</span><span class="p">,</span> <span class="n">GRAPHNET_ROOT_DIR</span><span class="p">)</span> + <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"$</span><span class="si">{graphnet}</span><span class="s2">"</span><span class="p">,</span> <span class="n">GRAPHNET_ROOT_DIR</span><span class="p">)</span> + <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"$</span><span class="si">{GRAPHNET}</span><span class="s2">"</span><span class="p">,</span> <span class="n">GRAPHNET_ROOT_DIR</span><span class="p">)</span> + <span class="p">)</span> + + <span class="nd">@save_dataset_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span> + <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span> + <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span><span class="p">,</span> + <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span><span class="p">,</span> + <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">string_selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> + <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct Dataset.</span> + +<span class="sd"> Args:</span> +<span class="sd"> path: Path to the file(s) from which this `Dataset` should read.</span> +<span class="sd"> pulsemaps: Name(s) of the pulse map series that should be used to</span> +<span class="sd"> construct the nodes on the individual graph objects, and their</span> +<span class="sd"> features. Multiple pulse series maps can be used, e.g., when</span> +<span class="sd"> different DOM types are stored in different maps.</span> +<span class="sd"> features: List of columns in the input files that should be used as</span> +<span class="sd"> node features on the graph objects.</span> +<span class="sd"> truth: List of event-level columns in the input files that should</span> +<span class="sd"> be used added as attributes on the graph objects.</span> +<span class="sd"> node_truth: List of node-level columns in the input files that</span> +<span class="sd"> should be used added as attributes on the graph objects.</span> +<span class="sd"> index_column: Name of the column in the input files that contains</span> +<span class="sd"> unique indicies to identify and map events across tables.</span> +<span class="sd"> truth_table: Name of the table containing event-level truth</span> +<span class="sd"> information.</span> +<span class="sd"> node_truth_table: Name of the table containing node-level truth</span> +<span class="sd"> information.</span> +<span class="sd"> string_selection: Subset of strings for which data should be read</span> +<span class="sd"> and used to construct graph objects. Defaults to None, meaning</span> +<span class="sd"> all strings for which data exists are used.</span> +<span class="sd"> selection: The events that should be read. This can be given either</span> +<span class="sd"> as list of indicies (in `index_column`); or a string-based</span> +<span class="sd"> selection used to query the `Dataset` for events passing the</span> +<span class="sd"> selection. Defaults to None, meaning that all events in the</span> +<span class="sd"> input files are read.</span> +<span class="sd"> dtype: Type of the feature tensor on the graph objects returned.</span> +<span class="sd"> loss_weight_table: Name of the table containing per-event loss</span> +<span class="sd"> weights.</span> +<span class="sd"> loss_weight_column: Name of the column in `loss_weight_table`</span> +<span class="sd"> containing per-event loss weights. This is also the name of the</span> +<span class="sd"> corresponding attribute assigned to the graph object.</span> +<span class="sd"> loss_weight_default_value: Default per-event loss weight.</span> +<span class="sd"> NOTE: This default value is only applied when</span> +<span class="sd"> `loss_weight_table` and `loss_weight_column` are specified, and</span> +<span class="sd"> in this case to events with no value in the corresponding</span> +<span class="sd"> table/column. That is, if no per-event loss weight table/column</span> +<span class="sd"> is provided, this value is ignored. Defaults to None.</span> +<span class="sd"> seed: Random number generator seed, used for selecting a random</span> +<span class="sd"> subset of events when resolving a string-based selection (e.g.,</span> +<span class="sd"> `"10000 random events ~ event_no % 5 > 0"` or `"20% random</span> +<span class="sd"> events ~ event_no % 5 > 0"`).</span> +<span class="sd"> graph_definition: Method that defines the graph representation.</span> +<span class="sd"> """</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pulsemaps</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">pulsemaps</span> <span class="o">=</span> <span class="p">[</span><span class="n">pulsemaps</span><span class="p">]</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">truth</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span> + + <span class="c1"># Resolve reference to `$GRAPHNET` in path(s)</span> + <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_resolve_graphnet_paths</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> + + <span class="c1"># Member variable(s)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_path</span> <span class="o">=</span> <span class="n">path</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span> <span class="o">=</span> <span class="kc">None</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemaps</span> <span class="o">=</span> <span class="n">pulsemaps</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_features</span> <span class="o">=</span> <span class="p">[</span><span class="n">index_column</span><span class="p">]</span> <span class="o">+</span> <span class="n">features</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span> <span class="o">=</span> <span class="p">[</span><span class="n">index_column</span><span class="p">]</span> <span class="o">+</span> <span class="n">truth</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span> <span class="o">=</span> <span class="n">index_column</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span> <span class="o">=</span> <span class="n">truth_table</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_default_value</span> <span class="o">=</span> <span class="n">loss_weight_default_value</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span> <span class="o">=</span> <span class="n">graph_definition</span> + + <span class="k">if</span> <span class="n">node_truth</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">node_truth_table</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">node_truth</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">node_truth</span> <span class="o">=</span> <span class="p">[</span><span class="n">node_truth</span><span class="p">]</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span> <span class="o">=</span> <span class="n">node_truth</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_table</span> <span class="o">=</span> <span class="n">node_truth_table</span> + + <span class="k">if</span> <span class="n">string_selection</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span> + <span class="p">(</span> + <span class="s2">"String selection detected.</span><span class="se">\n</span><span class="s2"> "</span> + <span class="sa">f</span><span class="s2">"Accepted strings: </span><span class="si">{</span><span class="n">string_selection</span><span class="si">}</span><span class="se">\n</span><span class="s2"> "</span> + <span class="s2">"All other strings are ignored!"</span> + <span class="p">)</span> + <span class="p">)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">string_selection</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> + <span class="n">string_selection</span> <span class="o">=</span> <span class="p">[</span><span class="n">string_selection</span><span class="p">]</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection</span> <span class="o">=</span> <span class="n">string_selection</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span> <span class="o">=</span> <span class="kc">None</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"string in </span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_string_selection</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span> <span class="o">=</span> <span class="n">loss_weight_column</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span> <span class="o">=</span> <span class="n">loss_weight_table</span> + <span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + <span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Error: no loss weight table specified"</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> + <span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span> <span class="ow">is</span> <span class="kc">None</span> + <span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Error: no loss weight column specified"</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="n">dtype</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_label_fns</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Data</span><span class="p">],</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="p">{}</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection_resolver</span> <span class="o">=</span> <span class="n">StringSelectionResolver</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span> + <span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Implementation-specific initialisation.</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_init</span><span class="p">()</span> + + <span class="c1"># Set unique indices</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span> + <span class="k">if</span> <span class="n">selection</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_indices</span><span class="p">()</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">selection</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_resolve_string_selection_to_indices</span><span class="p">(</span> + <span class="n">selection</span> + <span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span> <span class="o">=</span> <span class="n">selection</span> + + <span class="c1"># Purely internal member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="p">{}</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_remove_missing_columns</span><span class="p">()</span> + + <span class="c1"># Implementation-specific post-init code.</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_post_init</span><span class="p">()</span> + + <span class="c1"># Properties</span> + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">path</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Path to the file(s) from which this `Dataset` reads."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">truth_table</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Name of the table containing event-level truth information."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span> + + <span class="c1"># Abstract method(s)</span> + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Set internal representation needed to read data from input file."""</span> + + <span class="k">def</span> <span class="nf">_post_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Implementation-specific code executed after the main constructor."""</span> + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_get_all_indices</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return a list of all available values in `self._index_column`."""</span> + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_get_event_index</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return the event index corresponding to a `sequential_index`."""</span> + +<div class="viewcode-block" id="Dataset.query_table"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.query_table">[docs]</a> + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">query_table</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">table</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">columns</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">],</span> + <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Query a table at a specific index, optionally with some selection.</span> + +<span class="sd"> Args:</span> +<span class="sd"> table: Table to be queried.</span> +<span class="sd"> columns: Columns to read out.</span> +<span class="sd"> sequential_index: Sequentially numbered index</span> +<span class="sd"> (i.e. in [0,len(self))) of the event to query. This _may_</span> +<span class="sd"> differ from the indexation used in `self._indices`. If no value</span> +<span class="sd"> is provided, the entire column is returned.</span> +<span class="sd"> selection: Selection to be imposed before reading out data.</span> +<span class="sd"> Defaults to None.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> List of tuples containing the values in `columns`. If the `table`</span> +<span class="sd"> contains only scalar data for `columns`, a list of length 1 is</span> +<span class="sd"> returned</span> + +<span class="sd"> Raises:</span> +<span class="sd"> ColumnMissingException: If one or more element in `columns` is not</span> +<span class="sd"> present in `table`.</span> +<span class="sd"> """</span></div> + + + <span class="c1"># Public method(s)</span> +<div class="viewcode-block" id="Dataset.add_label"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.add_label">[docs]</a> + <span class="k">def</span> <span class="nf">add_label</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Data</span><span class="p">],</span> <span class="n">Any</span><span class="p">],</span> <span class="n">key</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Add custom graph label define using function `fn`."""</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">Label</span><span class="p">):</span> + <span class="n">key</span> <span class="o">=</span> <span class="n">fn</span><span class="o">.</span><span class="n">key</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span> + <span class="n">key</span><span class="p">,</span> <span class="nb">str</span> + <span class="p">),</span> <span class="s2">"Please specify a key for the custom label to be added."</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_label_fns</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"A custom label </span><span class="si">{</span><span class="n">key</span><span class="si">}</span><span class="s2"> has already been defined."</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_label_fns</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">fn</span></div> + + + <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return number of graphs in `Dataset`."""</span> + <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">)</span> + + <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return graph `Data` object at `index`."""</span> + <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="mi">0</span> <span class="o"><=</span> <span class="n">sequential_index</span> <span class="o"><</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)):</span> + <span class="k">raise</span> <span class="ne">IndexError</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Index </span><span class="si">{</span><span class="n">sequential_index</span><span class="si">}</span><span class="s2"> not in range [0, </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">]"</span> + <span class="p">)</span> + <span class="n">features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_query</span><span class="p">(</span> + <span class="n">sequential_index</span> + <span class="p">)</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_graph</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span><span class="p">)</span> + <span class="k">return</span> <span class="n">graph</span> + + <span class="c1"># Internal method(s)</span> + <span class="k">def</span> <span class="nf">_resolve_string_selection_to_indices</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">selection</span><span class="p">:</span> <span class="nb">str</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Resolve selection as string to list of indices.</span> + +<span class="sd"> Selections are expected to have pandas.DataFrame.query-compatible</span> +<span class="sd"> syntax, e.g., ``` "event_no % 5 > 0" ``` Selections may also specify a</span> +<span class="sd"> fixed number of events to randomly sample, e.g., ``` "10000 random</span> +<span class="sd"> events ~ event_no % 5 > 0" "20% random events ~ event_no % 5 > 0" ```</span> +<span class="sd"> """</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection_resolver</span><span class="o">.</span><span class="n">resolve</span><span class="p">(</span><span class="n">selection</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_remove_missing_columns</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Remove columns that are not present in the input file.</span> + +<span class="sd"> Columns are removed from `self._features` and `self._truth`.</span> +<span class="sd"> """</span> + <span class="c1"># Check if table is completely empty</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Dataset is empty."</span><span class="p">)</span> + <span class="k">return</span> + + <span class="c1"># Find missing features</span> + <span class="n">missing_features_set</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">)</span> + <span class="k">for</span> <span class="n">pulsemap</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemaps</span><span class="p">:</span> + <span class="n">missing</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_missing_columns</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">,</span> <span class="n">pulsemap</span><span class="p">)</span> + <span class="n">missing_features_set</span> <span class="o">=</span> <span class="n">missing_features_set</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="n">missing</span><span class="p">)</span> + + <span class="n">missing_features</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">missing_features_set</span><span class="p">)</span> + + <span class="c1"># Find missing truth variables</span> + <span class="n">missing_truth_variables</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_missing_columns</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span> + <span class="p">)</span> + + <span class="c1"># Remove missing features</span> + <span class="k">if</span> <span class="n">missing_features</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span> + <span class="s2">"Removing the following (missing) features: "</span> + <span class="o">+</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">missing_features</span><span class="p">)</span> + <span class="p">)</span> + <span class="k">for</span> <span class="n">missing_feature</span> <span class="ow">in</span> <span class="n">missing_features</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">missing_feature</span><span class="p">)</span> + + <span class="c1"># Remove missing truth variables</span> + <span class="k">if</span> <span class="n">missing_truth_variables</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span> + <span class="p">(</span> + <span class="s2">"Removing the following (missing) truth variables: "</span> + <span class="o">+</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">missing_truth_variables</span><span class="p">)</span> + <span class="p">)</span> + <span class="p">)</span> + <span class="k">for</span> <span class="n">missing_truth_variable</span> <span class="ow">in</span> <span class="n">missing_truth_variables</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">missing_truth_variable</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_check_missing_columns</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">table</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return a list missing columns in `table`."""</span> + <span class="k">for</span> <span class="n">column</span> <span class="ow">in</span> <span class="n">columns</span><span class="p">:</span> + <span class="k">try</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span><span class="n">table</span><span class="p">,</span> <span class="p">[</span><span class="n">column</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span> + <span class="k">except</span> <span class="n">ColumnMissingException</span><span class="p">:</span> + <span class="k">if</span> <span class="n">table</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="p">[</span><span class="n">table</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="p">[</span><span class="n">table</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">column</span><span class="p">)</span> + <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Dataset contains no entries for </span><span class="si">{</span><span class="n">column</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> + + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_missing_variables</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">table</span><span class="p">,</span> <span class="p">[])</span> + + <span class="k">def</span> <span class="nf">_query</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="nb">int</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span> + <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="o">...</span><span class="p">]],</span> + <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> + <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]],</span> + <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> + <span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Query file for event features and truth information.</span> + +<span class="sd"> The returned lists have lengths corresponding to the number of pulses</span> +<span class="sd"> in the event. Their constituent tuples have lengths corresponding to</span> +<span class="sd"> the number of features/attributes in each output</span> + +<span class="sd"> Args:</span> +<span class="sd"> sequential_index: Sequentially numbered index</span> +<span class="sd"> (i.e. in [0,len(self))) of the event to query. This _may_</span> +<span class="sd"> differ from the indexation used in `self._indices`.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Tuple containing pulse-level event features; event-level truth</span> +<span class="sd"> information; pulse-level truth information; and event-level</span> +<span class="sd"> loss weights, respectively.</span> +<span class="sd"> """</span> + <span class="n">features</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">for</span> <span class="n">pulsemap</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemaps</span><span class="p">:</span> + <span class="n">features_pulsemap</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span> + <span class="n">pulsemap</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span> + <span class="p">)</span> + <span class="n">features</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">features_pulsemap</span><span class="p">)</span> + + <span class="n">truth</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">,</span> <span class="n">sequential_index</span> + <span class="p">)[</span><span class="mi">0</span><span class="p">]</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">:</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + <span class="n">node_truth</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_table</span><span class="p">,</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">,</span> + <span class="n">sequential_index</span><span class="p">,</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_selection</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">node_truth</span> <span class="o">=</span> <span class="kc">None</span> + + <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># Default</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + <span class="n">loss_weight_list</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_table</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_table</span><span class="p">,</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span><span class="p">,</span> + <span class="n">sequential_index</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">loss_weight_list</span><span class="p">):</span> + <span class="n">loss_weight</span> <span class="o">=</span> <span class="n">loss_weight_list</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">loss_weight</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.0</span> + + <span class="k">return</span> <span class="n">features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span> + + <span class="k">def</span> <span class="nf">_create_graph</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="o">...</span><span class="p">]],</span> + <span class="n">truth</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> + <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Create Pytorch Data (i.e. graph) object.</span> + +<span class="sd"> Args:</span> +<span class="sd"> features: List of tuples, containing event features.</span> +<span class="sd"> truth: List of tuples, containing truth information.</span> +<span class="sd"> node_truth: List of tuples, containing node-level truth.</span> +<span class="sd"> loss_weight: A weight associated with the event for weighing the</span> +<span class="sd"> loss.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Graph object.</span> +<span class="sd"> """</span> + <span class="c1"># Convert nested list to simple dict</span> + <span class="n">truth_dict</span> <span class="o">=</span> <span class="p">{</span> + <span class="n">key</span><span class="p">:</span> <span class="n">truth</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">)</span> + <span class="p">}</span> + + <span class="c1"># Define custom labels</span> + <span class="n">labels_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_labels</span><span class="p">(</span><span class="n">truth_dict</span><span class="p">)</span> + + <span class="c1"># Convert nested list to simple dict</span> + <span class="k">if</span> <span class="n">node_truth</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">node_truth_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">node_truth</span><span class="p">)</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + <span class="n">node_truth_dict</span> <span class="o">=</span> <span class="p">{</span> + <span class="n">key</span><span class="p">:</span> <span class="n">node_truth_array</span><span class="p">[:,</span> <span class="n">index</span><span class="p">]</span> + <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">)</span> + <span class="p">}</span> + + <span class="c1"># Create list of truth dicts with labels</span> + <span class="n">truth_dicts</span> <span class="o">=</span> <span class="p">[</span><span class="n">labels_dict</span><span class="p">,</span> <span class="n">truth_dict</span><span class="p">]</span> + <span class="k">if</span> <span class="n">node_truth</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">truth_dicts</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">node_truth_dict</span><span class="p">)</span> + + <span class="c1"># Catch cases with no reconstructed pulses</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">features</span><span class="p">):</span> + <span class="n">node_features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">features</span><span class="p">)[</span> + <span class="p">:,</span> <span class="mi">1</span><span class="p">:</span> + <span class="p">]</span> <span class="c1"># first entry is index column</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">node_features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([])</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span> + + <span class="c1"># Construct graph data object</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span><span class="p">(</span> + <span class="n">node_features</span><span class="o">=</span><span class="n">node_features</span><span class="p">,</span> + <span class="n">node_feature_names</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">[</span> + <span class="mi">1</span><span class="p">:</span> + <span class="p">],</span> <span class="c1"># first entry is index column</span> + <span class="n">truth_dicts</span><span class="o">=</span><span class="n">truth_dicts</span><span class="p">,</span> + <span class="n">custom_label_functions</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_label_fns</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_column</span><span class="p">,</span> + <span class="n">loss_weight</span><span class="o">=</span><span class="n">loss_weight</span><span class="p">,</span> + <span class="n">loss_weight_default_value</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_default_value</span><span class="p">,</span> + <span class="n">data_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">graph</span> + + <span class="k">def</span> <span class="nf">_get_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">truth_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return dictionary of labels, to be added as graph attributes."""</span> + <span class="k">if</span> <span class="s2">"pid"</span> <span class="ow">in</span> <span class="n">truth_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">abs_pid</span> <span class="o">=</span> <span class="nb">abs</span><span class="p">(</span><span class="n">truth_dict</span><span class="p">[</span><span class="s2">"pid"</span><span class="p">])</span> + <span class="n">sim_type</span> <span class="o">=</span> <span class="n">truth_dict</span><span class="p">[</span><span class="s2">"sim_type"</span><span class="p">]</span> + + <span class="n">labels_dict</span> <span class="o">=</span> <span class="p">{</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">:</span> <span class="n">truth_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">],</span> + <span class="s2">"muon"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">13</span><span class="p">),</span> + <span class="s2">"muon_stopped"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">truth_dict</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"stopped_muon"</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">),</span> + <span class="s2">"noise"</span><span class="p">:</span> <span class="nb">int</span><span class="p">((</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">sim_type</span> <span class="o">!=</span> <span class="s2">"data"</span><span class="p">)),</span> + <span class="s2">"neutrino"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span> + <span class="p">(</span><span class="n">abs_pid</span> <span class="o">!=</span> <span class="mi">13</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">abs_pid</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">)</span> + <span class="p">),</span> <span class="c1"># @TODO: `abs_pid in [12,14,16]`?</span> + <span class="s2">"v_e"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">12</span><span class="p">),</span> + <span class="s2">"v_u"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">14</span><span class="p">),</span> + <span class="s2">"v_t"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">16</span><span class="p">),</span> + <span class="s2">"track"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span> + <span class="p">(</span><span class="n">abs_pid</span> <span class="o">==</span> <span class="mi">14</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">truth_dict</span><span class="p">[</span><span class="s2">"interaction_type"</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span> + <span class="p">),</span> + <span class="s2">"dbang"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_dbang_label</span><span class="p">(</span><span class="n">truth_dict</span><span class="p">),</span> + <span class="s2">"corsika"</span><span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">abs_pid</span> <span class="o">></span> <span class="mi">20</span><span class="p">),</span> + <span class="p">}</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">labels_dict</span> <span class="o">=</span> <span class="p">{</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">:</span> <span class="n">truth_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">],</span> + <span class="s2">"muon"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"muon_stopped"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"noise"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"neutrino"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"v_e"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"v_u"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"v_t"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"track"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"dbang"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="s2">"corsika"</span><span class="p">:</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="p">}</span> + <span class="k">return</span> <span class="n">labels_dict</span> + + <span class="k">def</span> <span class="nf">_get_dbang_label</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">truth_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Get label for double-bang classification."""</span> + <span class="k">try</span><span class="p">:</span> + <span class="n">label</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">truth_dict</span><span class="p">[</span><span class="s2">"dbang_decay_length"</span><span class="p">]</span> <span class="o">></span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> + <span class="k">return</span> <span class="n">label</span> + <span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span> + <span class="k">return</span> <span class="o">-</span><span class="mi">1</span></div> + + + +<div class="viewcode-block" id="EnsembleDataset"> +<a class="viewcode-back" href="../../../../api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.EnsembleDataset">[docs]</a> +<span class="k">class</span> <span class="nc">EnsembleDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ConcatDataset</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct a single dataset from a collection of datasets."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">datasets</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Dataset</span><span class="p">])</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct a single dataset from a collection of datasets.</span> + +<span class="sd"> Args:</span> +<span class="sd"> datasets: A collection of Datasets</span> +<span class="sd"> """</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">datasets</span><span class="o">=</span><span class="n">datasets</span><span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/data/dataset/parquet/parquet_dataset.html b/_modules/graphnet/data/dataset/parquet/parquet_dataset.html new file mode 100644 index 000000000..35d6b34d7 --- /dev/null +++ b/_modules/graphnet/data/dataset/parquet/parquet_dataset.html @@ -0,0 +1,500 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.data.dataset.parquet.parquet_dataset — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" /> + <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../../about.html" /> + <link rel="index" title="Index" href="../../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/data/dataset/parquet/parquet_dataset" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.data.dataset.parquet.parquet_dataset </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../../"versions.json"", + target_loc = "../../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-data-dataset-parquet-parquet-dataset--page-root">Source code for graphnet.data.dataset.parquet.parquet_dataset</h1><div class="highlight"><pre> +<span></span><span class="sd">"""`Dataset` class(es) for reading from Parquet files."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">cast</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">import</span> <span class="nn">awkward</span> <span class="k">as</span> <span class="nn">ak</span> + +<span class="kn">from</span> <span class="nn">graphnet.data.dataset.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">ColumnMissingException</span> + + +<div class="viewcode-block" id="ParquetDataset"> +<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset">[docs]</a> +<span class="k">class</span> <span class="nc">ParquetDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Pytorch dataset for reading from Parquet files."""</span> + + <span class="c1"># Implementing abstract method(s)</span> + <span class="k">def</span> <span class="nf">_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> + + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span> + <span class="s2">".parquet"</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"Format of input file `</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="si">}</span><span class="s2">` is not supported"</span> + + <span class="k">assert</span> <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span> <span class="ow">is</span> <span class="kc">None</span> + <span class="p">),</span> <span class="s2">"Argument `node_truth` is currently not supported."</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_table</span> <span class="ow">is</span> <span class="kc">None</span> + <span class="p">),</span> <span class="s2">"Argument `node_truth_table` is currently not supported."</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_string_selection</span> <span class="ow">is</span> <span class="kc">None</span> + <span class="p">),</span> <span class="s2">"Argument `string_selection` is currently not supported"</span> + + <span class="c1"># Set custom member variable(s)</span> + <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">from_parquet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="n">lazy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> + <span class="n">ak</span><span class="o">.</span><span class="n">from_parquet</span><span class="p">(</span><span class="n">file</span><span class="p">)</span> <span class="k">for</span> <span class="n">file</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span> + <span class="p">)</span> + + <span class="k">def</span> <span class="nf">_get_all_indices</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> + <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span> + <span class="nb">len</span><span class="p">(</span> + <span class="n">ak</span><span class="o">.</span><span class="n">to_numpy</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span><span class="p">][</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="p">]</span> + <span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + <span class="p">)</span> + <span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + + <span class="k">def</span> <span class="nf">_get_event_index</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> + <span class="n">index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> + <span class="k">if</span> <span class="n">sequential_index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">index</span> <span class="o">=</span> <span class="kc">None</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">index</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">)[</span><span class="n">sequential_index</span><span class="p">]</span> + + <span class="k">return</span> <span class="n">index</span> + + <span class="k">def</span> <span class="nf">_format_dictionary_result</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">dictionary</span><span class="p">:</span> <span class="n">Dict</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Convert the output of `ak.to_list()` into a list of tuples."""</span> + <span class="c1"># All scalar values</span> + <span class="k">if</span> <span class="nb">all</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">,</span> <span class="n">dictionary</span><span class="o">.</span><span class="n">values</span><span class="p">())):</span> + <span class="k">return</span> <span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">dictionary</span><span class="o">.</span><span class="n">values</span><span class="p">())]</span> + + <span class="c1"># All arrays should have same length</span> + <span class="n">array_lengths</span> <span class="o">=</span> <span class="p">[</span> + <span class="nb">len</span><span class="p">(</span><span class="n">values</span><span class="p">)</span> + <span class="k">for</span> <span class="n">values</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="o">.</span><span class="n">values</span><span class="p">()</span> + <span class="k">if</span> <span class="ow">not</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">values</span><span class="p">)</span> + <span class="p">]</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">array_lengths</span><span class="p">))</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span> + <span class="sa">f</span><span class="s2">"Arrays in </span><span class="si">{</span><span class="n">dictionary</span><span class="si">}</span><span class="s2"> have differing lengths "</span> + <span class="sa">f</span><span class="s2">"(</span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">array_lengths</span><span class="p">)</span><span class="si">}</span><span class="s2">)."</span> + <span class="p">)</span> + <span class="n">nb_elements</span> <span class="o">=</span> <span class="n">array_lengths</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> + + <span class="c1"># Broadcast scalars</span> + <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">:</span> + <span class="n">value</span> <span class="o">=</span> <span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> + <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">isscalar</span><span class="p">(</span><span class="n">value</span><span class="p">):</span> + <span class="n">dictionary</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span> + <span class="n">value</span><span class="p">,</span> <span class="n">repeats</span><span class="o">=</span><span class="n">nb_elements</span> + <span class="p">)</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + + <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">dictionary</span><span class="o">.</span><span class="n">values</span><span class="p">()))))</span> + +<div class="viewcode-block" id="ParquetDataset.query_table"> +<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table">[docs]</a> + <span class="k">def</span> <span class="nf">query_table</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">table</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">columns</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">],</span> + <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Query table at a specific index, optionally with some selection."""</span> + <span class="c1"># Check(s)</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="n">selection</span> <span class="ow">is</span> <span class="kc">None</span> + <span class="p">),</span> <span class="s2">"Argument `selection` is currently not supported"</span> + + <span class="n">index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_event_index</span><span class="p">(</span><span class="n">sequential_index</span><span class="p">)</span> + + <span class="k">try</span><span class="p">:</span> + <span class="k">if</span> <span class="n">index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">ak_array</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span><span class="p">[</span><span class="n">table</span><span class="p">][</span><span class="n">columns</span><span class="p">][:]</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">ak_array</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parquet_hook</span><span class="p">[</span><span class="n">table</span><span class="p">][</span><span class="n">columns</span><span class="p">][</span><span class="n">index</span><span class="p">]</span> + <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span> + <span class="k">if</span> <span class="s2">"does not exist (not in record)"</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">):</span> + <span class="k">raise</span> <span class="n">ColumnMissingException</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">))</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="n">e</span> + + <span class="n">output</span> <span class="o">=</span> <span class="n">ak_array</span><span class="o">.</span><span class="n">to_list</span><span class="p">()</span> + + <span class="n">result</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]</span> <span class="o">=</span> <span class="p">[]</span> + + <span class="c1"># Querying single index</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span> + <span class="k">assert</span> <span class="nb">list</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="o">==</span> <span class="n">columns</span> + <span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format_dictionary_result</span><span class="p">(</span><span class="n">output</span><span class="p">)</span> + + <span class="c1"># Querying entire columm</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="k">for</span> <span class="n">dictionary</span> <span class="ow">in</span> <span class="n">output</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">list</span><span class="p">(</span><span class="n">dictionary</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="o">==</span> <span class="n">columns</span> + <span class="n">result</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_format_dictionary_result</span><span class="p">(</span><span class="n">dictionary</span><span class="p">))</span> + + <span class="k">return</span> <span class="n">result</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html new file mode 100644 index 000000000..a6734e1c8 --- /dev/null +++ b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html @@ -0,0 +1,515 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.data.dataset.sqlite.sqlite_dataset — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" /> + <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../../about.html" /> + <link rel="index" title="Index" href="../../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/data/dataset/sqlite/sqlite_dataset" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.data.dataset.sqlite.sqlite_dataset </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../../"versions.json"", + target_loc = "../../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-data-dataset-sqlite-sqlite-dataset--page-root">Source code for graphnet.data.dataset.sqlite.sqlite_dataset</h1><div class="highlight"><pre> +<span></span><span class="sd">"""`Dataset` class(es) for reading data from SQLite databases."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span> +<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> +<span class="kn">import</span> <span class="nn">sqlite3</span> + +<span class="kn">from</span> <span class="nn">graphnet.data.dataset.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span><span class="p">,</span> <span class="n">ColumnMissingException</span> + + +<div class="viewcode-block" id="SQLiteDataset"> +<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset">[docs]</a> +<span class="k">class</span> <span class="nc">SQLiteDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Pytorch dataset for reading data from SQLite databases."""</span> + + <span class="c1"># Implementing abstract method(s)</span> + <span class="k">def</span> <span class="nf">_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> + <span class="c1"># Check(s)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span> <span class="o">=</span> <span class="kc">False</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">sqlite3</span><span class="o">.</span><span class="n">Connection</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="o">=</span> <span class="kc">None</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span> + <span class="s2">".db"</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"Format of input file `</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="si">}</span><span class="s2">` is not supported."</span> + + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_current_database</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + + <span class="c1"># Set custom member variable(s)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_features_string</span> <span class="o">=</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_truth_string</span> <span class="o">=</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_truth_string</span> <span class="o">=</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_truth</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">sqlite3</span><span class="o">.</span><span class="n">Connection</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + + <span class="k">def</span> <span class="nf">_post_init</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_close_connection</span><span class="p">()</span> + +<div class="viewcode-block" id="SQLiteDataset.query_table"> +<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table">[docs]</a> + <span class="k">def</span> <span class="nf">query_table</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">table</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">columns</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">],</span> + <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Query table at a specific index, optionally with some selection."""</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">columns</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="n">columns</span> <span class="o">=</span> <span class="s2">", "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">columns</span><span class="p">)</span> + + <span class="k">if</span> <span class="ow">not</span> <span class="n">selection</span><span class="p">:</span> <span class="c1"># I.e., `None` or `""`</span> + <span class="n">selection</span> <span class="o">=</span> <span class="s2">"1=1"</span> <span class="c1"># Identically true, to select all</span> + + <span class="n">index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_event_index</span><span class="p">(</span><span class="n">sequential_index</span><span class="p">)</span> + + <span class="c1"># Query table</span> + <span class="k">assert</span> <span class="n">index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_establish_connection</span><span class="p">(</span><span class="n">index</span><span class="p">)</span> + <span class="k">try</span><span class="p">:</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> + <span class="k">if</span> <span class="n">sequential_index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">combined_selections</span> <span class="o">=</span> <span class="n">selection</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">combined_selections</span> <span class="o">=</span> <span class="p">(</span> + <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2"> and </span><span class="si">{</span><span class="n">selection</span><span class="si">}</span><span class="s2">"</span> + <span class="p">)</span> + + <span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span><span class="o">.</span><span class="n">execute</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"SELECT </span><span class="si">{</span><span class="n">columns</span><span class="si">}</span><span class="s2"> FROM </span><span class="si">{</span><span class="n">table</span><span class="si">}</span><span class="s2"> WHERE "</span> + <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">combined_selections</span><span class="si">}</span><span class="s2">"</span> + <span class="p">)</span><span class="o">.</span><span class="n">fetchall</span><span class="p">()</span> + <span class="k">except</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">OperationalError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span> + <span class="k">if</span> <span class="s2">"no such column"</span> <span class="ow">in</span> <span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">):</span> + <span class="k">raise</span> <span class="n">ColumnMissingException</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">e</span><span class="p">))</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="n">e</span> + <span class="k">return</span> <span class="n">result</span></div> + + + <span class="k">def</span> <span class="nf">_get_all_indices</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_establish_connection</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> + <span class="n">indices</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_sql_query</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"SELECT </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_index_column</span><span class="si">}</span><span class="s2"> FROM </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth_table</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_close_connection</span><span class="p">()</span> + <span class="k">return</span> <span class="n">indices</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + + <span class="k">def</span> <span class="nf">_get_event_index</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> + <span class="n">index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span> + <span class="k">if</span> <span class="n">sequential_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">index_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">[</span><span class="n">sequential_index</span><span class="p">]</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index_</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> + <span class="n">index</span> <span class="o">=</span> <span class="n">index_</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index_</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="n">index</span> <span class="o">=</span> <span class="n">index_</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> + <span class="k">return</span> <span class="n">index</span> + + <span class="c1"># Custom, internal method(s)</span> + <span class="c1"># @TODO: Is it necessary to return anything here?</span> + <span class="k">def</span> <span class="nf">_establish_connection</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">i</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"SQLiteDataset"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Make sure that a sqlite3 connection is open."""</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_path</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_indices</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">for</span> <span class="n">database</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span><span class="p">:</span> + <span class="n">con</span> <span class="o">=</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="n">database</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">con</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span> <span class="o">=</span> <span class="kc">True</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="p">[</span><span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span> + <span class="k">if</span> <span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_current_database</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="p">[</span><span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_current_database</span> <span class="o">=</span> <span class="n">indices</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> + <span class="k">return</span> <span class="bp">self</span> + + <span class="c1"># @TODO: Is it necessary to return anything here?</span> + <span class="k">def</span> <span class="nf">_close_connection</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"SQLiteDataset"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Make sure that no sqlite3 connection is open.</span> + +<span class="sd"> This is necessary to calls this before passing to</span> +<span class="sd"> `torch.DataLoader` such that the dataset replica on each worker</span> +<span class="sd"> is required to create its own connection (thereby avoiding</span> +<span class="sd"> `sqlite3.DatabaseError: database disk image is malformed` errors</span> +<span class="sd"> due to inability to use sqlite3 connection accross processes.</span> +<span class="sd"> """</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span><span class="o">.</span><span class="n">close</span><span class="p">()</span> + <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="kc">None</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_database_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span><span class="p">:</span> + <span class="k">for</span> <span class="n">con</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span><span class="p">:</span> + <span class="n">con</span><span class="o">.</span><span class="n">close</span><span class="p">()</span> + <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_all_connections_established</span> <span class="o">=</span> <span class="kc">False</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_conn</span> <span class="o">=</span> <span class="kc">None</span> + <span class="k">return</span> <span class="bp">self</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html new file mode 100644 index 000000000..44acbb9c5 --- /dev/null +++ b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html @@ -0,0 +1,515 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.data.dataset.sqlite.sqlite_dataset_perturbed — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" /> + <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../../about.html" /> + <link rel="index" title="Index" href="../../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.data.dataset.sqlite.sqlite_dataset_perturbed </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../../"versions.json"", + target_loc = "../../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root">Source code for graphnet.data.dataset.sqlite.sqlite_dataset_perturbed</h1><div class="highlight"><pre> +<span></span><span class="sd">"""`Dataset` class(es) for reading perturbed data from SQLite databases."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">from</span> <span class="nn">numpy.random</span> <span class="kn">import</span> <span class="n">default_rng</span><span class="p">,</span> <span class="n">Generator</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">.sqlite_dataset</span> <span class="kn">import</span> <span class="n">SQLiteDataset</span> + + +<div class="viewcode-block" id="SQLiteDatasetPerturbed"> +<a class="viewcode-back" href="../../../../../api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed">[docs]</a> +<span class="k">class</span> <span class="nc">SQLiteDatasetPerturbed</span><span class="p">(</span><span class="n">SQLiteDataset</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Pytorch dataset for reading perturbed data from SQLite databases.</span> + +<span class="sd"> This including a pre-processing step, where the input data is randomly</span> +<span class="sd"> perturbed according to given per-feature "noise" levels. This is intended</span> +<span class="sd"> to test the stability of a trained model under small changes to the input</span> +<span class="sd"> parameters.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span> + <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">perturbation_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> + <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span><span class="p">,</span> + <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span><span class="p">,</span> + <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">string_selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">dtype</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> + <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Generator</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct SQLiteDatasetPerturbed.</span> + +<span class="sd"> Args:</span> +<span class="sd"> path: Path to the file(s) from which this `Dataset` should read.</span> +<span class="sd"> pulsemaps: Name(s) of the pulse map series that should be used to</span> +<span class="sd"> construct the nodes on the individual graph objects, and their</span> +<span class="sd"> features. Multiple pulse series maps can be used, e.g., when</span> +<span class="sd"> different DOM types are stored in different maps.</span> +<span class="sd"> features: List of columns in the input files that should be used as</span> +<span class="sd"> node features on the graph objects.</span> +<span class="sd"> truth: List of event-level columns in the input files that should</span> +<span class="sd"> be used added as attributes on the graph objects.</span> +<span class="sd"> perturbation_dict (Dict[str, float]): Dictionary mapping a feature</span> +<span class="sd"> name to a standard deviation according to which the values for</span> +<span class="sd"> this feature should be randomly perturbed.</span> +<span class="sd"> node_truth: List of node-level columns in the input files that</span> +<span class="sd"> should be used added as attributes on the graph objects.</span> +<span class="sd"> index_column: Name of the column in the input files that contains</span> +<span class="sd"> unique indicies to identify and map events across tables.</span> +<span class="sd"> truth_table: Name of the table containing event-level truth</span> +<span class="sd"> information.</span> +<span class="sd"> node_truth_table: Name of the table containing node-level truth</span> +<span class="sd"> information.</span> +<span class="sd"> string_selection: Subset of strings for which data should be read</span> +<span class="sd"> and used to construct graph objects. Defaults to None, meaning</span> +<span class="sd"> all strings for which data exists are used.</span> +<span class="sd"> selection: List of indicies (in `index_column`) of the events in</span> +<span class="sd"> the input files that should be read. Defaults to None, meaning</span> +<span class="sd"> that all events in the input files are read.</span> +<span class="sd"> dtype: Type of the feature tensor on the graph objects returned.</span> +<span class="sd"> loss_weight_table: Name of the table containing per-event loss</span> +<span class="sd"> weights.</span> +<span class="sd"> loss_weight_column: Name of the column in `loss_weight_table`</span> +<span class="sd"> containing per-event loss weights. This is also the name of the</span> +<span class="sd"> corresponding attribute assigned to the graph object.</span> +<span class="sd"> loss_weight_default_value: Default per-event loss weight.</span> +<span class="sd"> NOTE: This default value is only applied when</span> +<span class="sd"> `loss_weight_table` and `loss_weight_column` are specified, and</span> +<span class="sd"> in this case to events with no value in the corresponding</span> +<span class="sd"> table/column. That is, if no per-event loss weight table/column</span> +<span class="sd"> is provided, this value is ignored. Defaults to None.</span> +<span class="sd"> seed: Optional seed for random number generation. Defaults to None.</span> +<span class="sd"> """</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span> + <span class="n">path</span><span class="o">=</span><span class="n">path</span><span class="p">,</span> + <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span> + <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span> + <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span> + <span class="n">node_truth</span><span class="o">=</span><span class="n">node_truth</span><span class="p">,</span> + <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span> + <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span> + <span class="n">node_truth_table</span><span class="o">=</span><span class="n">node_truth_table</span><span class="p">,</span> + <span class="n">string_selection</span><span class="o">=</span><span class="n">string_selection</span><span class="p">,</span> + <span class="n">selection</span><span class="o">=</span><span class="n">selection</span><span class="p">,</span> + <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> + <span class="n">loss_weight_table</span><span class="o">=</span><span class="n">loss_weight_table</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="o">=</span><span class="n">loss_weight_column</span><span class="p">,</span> + <span class="n">loss_weight_default_value</span><span class="o">=</span><span class="n">loss_weight_default_value</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Custom member variables</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">perturbation_dict</span><span class="p">,</span> <span class="nb">dict</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">perturbation_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span> + <span class="n">perturbation_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_dict</span> <span class="o">=</span> <span class="n">perturbation_dict</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_cols</span> <span class="o">=</span> <span class="p">[</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> + <span class="p">]</span> + + <span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">seed</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">rng</span> <span class="o">=</span> <span class="n">default_rng</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">seed</span><span class="p">,</span> <span class="n">Generator</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">rng</span> <span class="o">=</span> <span class="n">seed</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> + <span class="s2">"Invalid seed. Must be an int or a numpy Generator."</span> + <span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">rng</span> <span class="o">=</span> <span class="n">default_rng</span><span class="p">()</span> + + <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sequential_index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return graph `Data` object at `index`."""</span> + <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="mi">0</span> <span class="o"><=</span> <span class="n">sequential_index</span> <span class="o"><</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)):</span> + <span class="k">raise</span> <span class="ne">IndexError</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Index </span><span class="si">{</span><span class="n">sequential_index</span><span class="si">}</span><span class="s2"> not in range [0, </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">]"</span> + <span class="p">)</span> + <span class="n">features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_query</span><span class="p">(</span> + <span class="n">sequential_index</span> + <span class="p">)</span> + <span class="n">perturbed_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perturb_features</span><span class="p">(</span><span class="n">features</span><span class="p">)</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_graph</span><span class="p">(</span> + <span class="n">perturbed_features</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">node_truth</span><span class="p">,</span> <span class="n">loss_weight</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">graph</span> + + <span class="k">def</span> <span class="nf">_perturb_features</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="o">...</span><span class="p">]]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="o">...</span><span class="p">]]:</span> + <span class="n">features_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">features</span><span class="p">)</span> + <span class="n">perturbed_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rng</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> + <span class="n">loc</span><span class="o">=</span><span class="n">features_array</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_cols</span><span class="p">],</span> + <span class="n">scale</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span> + <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_dict</span><span class="o">.</span><span class="n">values</span><span class="p">()),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float</span> + <span class="p">),</span> + <span class="p">)</span> + <span class="n">features_array</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perturbation_cols</span><span class="p">]</span> <span class="o">=</span> <span class="n">perturbed_features</span> + <span class="k">return</span> <span class="n">features_array</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/data/extractors/i3extractor.html b/_modules/graphnet/data/extractors/i3extractor.html index cae615098..0680557f4 100644 --- a/_modules/graphnet/data/extractors/i3extractor.html +++ b/_modules/graphnet/data/extractors/i3extractor.html @@ -464,7 +464,7 @@ <h1 id="modules-graphnet-data-extractors-i3extractor--page-root">Source code for </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3featureextractor.html b/_modules/graphnet/data/extractors/i3featureextractor.html index 06d9d3e68..de384a97e 100644 --- a/_modules/graphnet/data/extractors/i3featureextractor.html +++ b/_modules/graphnet/data/extractors/i3featureextractor.html @@ -647,7 +647,7 @@ <h1 id="modules-graphnet-data-extractors-i3featureextractor--page-root">Source c </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3genericextractor.html b/_modules/graphnet/data/extractors/i3genericextractor.html index 0324f9087..9c7d786f3 100644 --- a/_modules/graphnet/data/extractors/i3genericextractor.html +++ b/_modules/graphnet/data/extractors/i3genericextractor.html @@ -636,7 +636,7 @@ <h1 id="modules-graphnet-data-extractors-i3genericextractor--page-root">Source c </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3hybridrecoextractor.html b/_modules/graphnet/data/extractors/i3hybridrecoextractor.html index 21824f54b..ee80f730f 100644 --- a/_modules/graphnet/data/extractors/i3hybridrecoextractor.html +++ b/_modules/graphnet/data/extractors/i3hybridrecoextractor.html @@ -400,7 +400,7 @@ <h1 id="modules-graphnet-data-extractors-i3hybridrecoextractor--page-root">Sourc </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html b/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html index 2a3a45023..d295f81a2 100644 --- a/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html +++ b/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html @@ -407,7 +407,7 @@ <h1 id="modules-graphnet-data-extractors-i3ntmuonlabelsextractor--page-root">Sou </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3particleextractor.html b/_modules/graphnet/data/extractors/i3particleextractor.html index 8d17ef0ee..60463b59a 100644 --- a/_modules/graphnet/data/extractors/i3particleextractor.html +++ b/_modules/graphnet/data/extractors/i3particleextractor.html @@ -392,7 +392,7 @@ <h1 id="modules-graphnet-data-extractors-i3particleextractor--page-root">Source </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3pisaextractor.html b/_modules/graphnet/data/extractors/i3pisaextractor.html index 635ba4537..0752dc39a 100644 --- a/_modules/graphnet/data/extractors/i3pisaextractor.html +++ b/_modules/graphnet/data/extractors/i3pisaextractor.html @@ -385,7 +385,7 @@ <h1 id="modules-graphnet-data-extractors-i3pisaextractor--page-root">Source code </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3quesoextractor.html b/_modules/graphnet/data/extractors/i3quesoextractor.html index bcac49e76..6ad9572a6 100644 --- a/_modules/graphnet/data/extractors/i3quesoextractor.html +++ b/_modules/graphnet/data/extractors/i3quesoextractor.html @@ -395,7 +395,7 @@ <h1 id="modules-graphnet-data-extractors-i3quesoextractor--page-root">Source cod </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3retroextractor.html b/_modules/graphnet/data/extractors/i3retroextractor.html index dea96d47a..b76bed436 100644 --- a/_modules/graphnet/data/extractors/i3retroextractor.html +++ b/_modules/graphnet/data/extractors/i3retroextractor.html @@ -467,7 +467,7 @@ <h1 id="modules-graphnet-data-extractors-i3retroextractor--page-root">Source cod </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3splinempeextractor.html b/_modules/graphnet/data/extractors/i3splinempeextractor.html index 79f5f9457..dd8d383bd 100644 --- a/_modules/graphnet/data/extractors/i3splinempeextractor.html +++ b/_modules/graphnet/data/extractors/i3splinempeextractor.html @@ -379,7 +379,7 @@ <h1 id="modules-graphnet-data-extractors-i3splinempeextractor--page-root">Source </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3truthextractor.html b/_modules/graphnet/data/extractors/i3truthextractor.html index 330ad508b..19d005d0d 100644 --- a/_modules/graphnet/data/extractors/i3truthextractor.html +++ b/_modules/graphnet/data/extractors/i3truthextractor.html @@ -781,7 +781,7 @@ <h1 id="modules-graphnet-data-extractors-i3truthextractor--page-root">Source cod </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/i3tumextractor.html b/_modules/graphnet/data/extractors/i3tumextractor.html index d22868f07..89adb77c2 100644 --- a/_modules/graphnet/data/extractors/i3tumextractor.html +++ b/_modules/graphnet/data/extractors/i3tumextractor.html @@ -382,7 +382,7 @@ <h1 id="modules-graphnet-data-extractors-i3tumextractor--page-root">Source code </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/utilities/collections.html b/_modules/graphnet/data/extractors/utilities/collections.html index 11ccbd8f5..14835acc6 100644 --- a/_modules/graphnet/data/extractors/utilities/collections.html +++ b/_modules/graphnet/data/extractors/utilities/collections.html @@ -436,7 +436,7 @@ <h1 id="modules-graphnet-data-extractors-utilities-collections--page-root">Sourc </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/utilities/frames.html b/_modules/graphnet/data/extractors/utilities/frames.html index 4dfe93249..656cb6033 100644 --- a/_modules/graphnet/data/extractors/utilities/frames.html +++ b/_modules/graphnet/data/extractors/utilities/frames.html @@ -439,7 +439,7 @@ <h1 id="modules-graphnet-data-extractors-utilities-frames--page-root">Source cod </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/extractors/utilities/types.html b/_modules/graphnet/data/extractors/utilities/types.html index f2d4ecd65..60de6d341 100644 --- a/_modules/graphnet/data/extractors/utilities/types.html +++ b/_modules/graphnet/data/extractors/utilities/types.html @@ -660,7 +660,7 @@ <h1 id="modules-graphnet-data-extractors-utilities-types--page-root">Source code </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/parquet/parquet_dataconverter.html b/_modules/graphnet/data/parquet/parquet_dataconverter.html index 7cf5c3507..455153448 100644 --- a/_modules/graphnet/data/parquet/parquet_dataconverter.html +++ b/_modules/graphnet/data/parquet/parquet_dataconverter.html @@ -407,7 +407,7 @@ <h1 id="modules-graphnet-data-parquet-parquet-dataconverter--page-root">Source c </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/pipeline.html b/_modules/graphnet/data/pipeline.html new file mode 100644 index 000000000..90602e63c --- /dev/null +++ b/_modules/graphnet/data/pipeline.html @@ -0,0 +1,593 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.data.pipeline — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/data/pipeline" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.data.pipeline </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-data-pipeline--page-root">Source code for graphnet.data.pipeline</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Class(es) used for analysis in PISA."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span> +<span class="kn">import</span> <span class="nn">dill</span> +<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">reduce</span> +<span class="kn">import</span> <span class="nn">os</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">Trainer</span> +<span class="kn">import</span> <span class="nn">sqlite3</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span> + +<span class="kn">from</span> <span class="nn">graphnet.data.sqlite.sqlite_utilities</span> <span class="kn">import</span> <span class="n">create_table_and_save_to_sql</span> +<span class="kn">from</span> <span class="nn">graphnet.training.utils</span> <span class="kn">import</span> <span class="n">get_predictions</span><span class="p">,</span> <span class="n">make_dataloader</span> +<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span> + + +<div class="viewcode-block" id="InSQLitePipeline"> +<a class="viewcode-back" href="../../../api/graphnet.data.pipeline.html#graphnet.data.pipeline.InSQLitePipeline">[docs]</a> +<span class="k">class</span> <span class="nc">InSQLitePipeline</span><span class="p">(</span><span class="n">ABC</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Create a SQLite database for PISA analysis.</span> + +<span class="sd"> The database will contain truth and GNN predictions and, if available,</span> +<span class="sd"> RETRO reconstructions.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">module_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> + <span class="n">retro_table_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"retro"</span><span class="p">,</span> + <span class="n">outdir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> + <span class="n">n_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> + <span class="n">pipeline_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"pipeline"</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Initialise the pipeline.</span> + +<span class="sd"> Args:</span> +<span class="sd"> module_dict: A dictionary with GNN modules from GraphNet. E.g.</span> +<span class="sd"> {'energy': gnn_module_for_energy_regression}</span> +<span class="sd"> features: List of input features for the GNN modules.</span> +<span class="sd"> truth: List of truth for the GNN ModuleList.</span> +<span class="sd"> device: The device used for computation.</span> +<span class="sd"> retro_table_name: Name of the retro table for.</span> +<span class="sd"> outdir: the directory in which the pipeline database will be</span> +<span class="sd"> stored.</span> +<span class="sd"> batch_size: Batch size for inference.</span> +<span class="sd"> n_workers: Number of workers used in dataloading.</span> +<span class="sd"> pipeline_name: Name of the pipeline. If such a pipeline already</span> +<span class="sd"> exists, an error will be prompted to avoid overwriting.</span> +<span class="sd"> """</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_pipeline_name</span> <span class="o">=</span> <span class="n">pipeline_name</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_device</span> <span class="o">=</span> <span class="n">device</span> + <span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span> <span class="o">=</span> <span class="n">n_workers</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_features</span> <span class="o">=</span> <span class="n">features</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_truth</span> <span class="o">=</span> <span class="n">truth</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_batch_size</span> <span class="o">=</span> <span class="n">batch_size</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_outdir</span> <span class="o">=</span> <span class="n">outdir</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_module_dict</span> <span class="o">=</span> <span class="n">module_dict</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_retro_table_name</span> <span class="o">=</span> <span class="n">retro_table_name</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + + <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">database</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span> + <span class="n">chunk_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1000000</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Run inference of each field in self._module_dict[target][''].</span> + +<span class="sd"> Args:</span> +<span class="sd"> database: Path to database with pulsemap and truth.</span> +<span class="sd"> pulsemap: Name of pulsemaps.</span> +<span class="sd"> graph_definition: GraphDefinition for Dataset</span> +<span class="sd"> chunk_size: database will be sliced in chunks of size `chunk_size`.</span> +<span class="sd"> Use this parameter to control memory usage.</span> +<span class="sd"> """</span> + <span class="n">outdir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_outdir</span><span class="p">(</span><span class="n">database</span><span class="p">)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_device</span><span class="p">,</span> <span class="nb">str</span> + <span class="p">):</span> <span class="c1"># Because pytorch lightning insists on breaking pytorch cuda device naming scheme</span> + <span class="n">device</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_device</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> + <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">outdir</span><span class="p">):</span> + <span class="n">dataloaders</span><span class="p">,</span> <span class="n">event_batches</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_setup_dataloaders</span><span class="p">(</span> + <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span> + <span class="n">chunk_size</span><span class="o">=</span><span class="n">chunk_size</span><span class="p">,</span> + <span class="n">db</span><span class="o">=</span><span class="n">database</span><span class="p">,</span> + <span class="n">pulsemap</span><span class="o">=</span><span class="n">pulsemap</span><span class="p">,</span> + <span class="n">selection</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> + <span class="p">)</span> + <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span> + <span class="k">for</span> <span class="n">dataloader</span> <span class="ow">in</span> <span class="n">dataloaders</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">"CHUNK </span><span class="si">%s</span><span class="s2"> / </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataloaders</span><span class="p">)))</span> + <span class="n">df</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">)</span> + <span class="n">truth</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_truth</span><span class="p">(</span><span class="n">database</span><span class="p">,</span> <span class="n">event_batches</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span> + <span class="n">retro</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_retro</span><span class="p">(</span><span class="n">database</span><span class="p">,</span> <span class="n">event_batches</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_append_to_pipeline</span><span class="p">(</span><span class="n">outdir</span><span class="p">,</span> <span class="n">truth</span><span class="p">,</span> <span class="n">retro</span><span class="p">,</span> <span class="n">df</span><span class="p">)</span> + <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">outdir</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span> + <span class="s2">"WARNING - Pipeline named </span><span class="si">%s</span><span class="s2"> already exists! </span><span class="se">\n</span><span class="s2"> Please rename pipeline!"</span> + <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pipeline_name</span> + <span class="p">)</span> + + <span class="k">def</span> <span class="nf">_setup_dataloaders</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">chunk_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">DataLoader</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]]:</span> + <span class="k">if</span> <span class="n">selection</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">selection</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_event_nos</span><span class="p">(</span><span class="n">db</span><span class="p">)</span> + <span class="n">n_chunks</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">selection</span><span class="p">)</span> <span class="o">/</span> <span class="n">chunk_size</span><span class="p">)</span> + <span class="n">event_batches</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array_split</span><span class="p">(</span><span class="n">selection</span><span class="p">,</span> <span class="n">n_chunks</span><span class="p">)</span> + <span class="n">dataloaders</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">event_batches</span><span class="p">:</span> + <span class="n">dataloaders</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> + <span class="n">make_dataloader</span><span class="p">(</span> + <span class="n">db</span><span class="o">=</span><span class="n">db</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span> + <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemap</span><span class="p">,</span> + <span class="n">features</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">,</span> + <span class="n">truth</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_truth</span><span class="p">,</span> + <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_batch_size</span><span class="p">,</span> + <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> + <span class="n">selection</span><span class="o">=</span><span class="n">batch</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> + <span class="n">num_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">n_workers</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="o">=</span><span class="n">persistent_workers</span><span class="p">,</span> + <span class="p">)</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">dataloaders</span><span class="p">,</span> <span class="n">event_batches</span> + + <span class="k">def</span> <span class="nf">_get_all_event_nos</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span> + <span class="k">with</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="n">db</span><span class="p">)</span> <span class="k">as</span> <span class="n">con</span><span class="p">:</span> + <span class="n">query</span> <span class="o">=</span> <span class="s2">"SELECT event_no FROM truth"</span> + <span class="n">selection</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_sql</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">con</span><span class="p">)</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + <span class="k">return</span> <span class="n">selection</span> + + <span class="k">def</span> <span class="nf">_combine_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataframes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">])</span> <span class="o">-></span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span> + <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">merge</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">on</span><span class="o">=</span><span class="s2">"event_no"</span><span class="p">),</span> <span class="n">dataframes</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_inference</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span> + <span class="n">dataframes</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">for</span> <span class="n">target</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_module_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="c1"># dataloader = iter(dataloader)</span> + <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">devices</span><span class="o">=</span><span class="p">[</span><span class="n">device</span><span class="p">],</span> <span class="n">accelerator</span><span class="o">=</span><span class="s2">"gpu"</span><span class="p">)</span> + <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_module_dict</span><span class="p">[</span><span class="n">target</span><span class="p">][</span><span class="s2">"path"</span><span class="p">],</span> + <span class="n">map_location</span><span class="o">=</span><span class="s2">"cpu"</span><span class="p">,</span> + <span class="n">pickle_module</span><span class="o">=</span><span class="n">dill</span><span class="p">,</span> + <span class="p">)</span> + <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> + <span class="n">model</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span> + <span class="n">results</span> <span class="o">=</span> <span class="n">get_predictions</span><span class="p">(</span> + <span class="n">trainer</span><span class="p">,</span> + <span class="n">model</span><span class="p">,</span> + <span class="n">dataloader</span><span class="p">,</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_module_dict</span><span class="p">[</span><span class="n">target</span><span class="p">][</span><span class="s2">"output_column_names"</span><span class="p">],</span> + <span class="n">additional_attributes</span><span class="o">=</span><span class="p">[</span><span class="s2">"event_no"</span><span class="p">],</span> + <span class="p">)</span> + <span class="n">dataframes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> + <span class="n">results</span><span class="o">.</span><span class="n">sort_values</span><span class="p">(</span><span class="s2">"event_no"</span><span class="p">)</span><span class="o">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="p">)</span> + <span class="n">df</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_combine_outputs</span><span class="p">(</span><span class="n">dataframes</span><span class="p">)</span> + <span class="k">return</span> <span class="n">df</span> + + <span class="k">def</span> <span class="nf">_get_outdir</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">database</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_outdir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">database_name</span> <span class="o">=</span> <span class="n">database</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">3</span><span class="p">]</span> + <span class="n">outdir</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">database</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">database_name</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> + <span class="o">+</span> <span class="n">database_name</span> + <span class="o">+</span> <span class="s2">"/pipelines/"</span> + <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pipeline_name</span> + <span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">outdir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_outdir</span> + <span class="k">return</span> <span class="n">outdir</span> + + <span class="k">def</span> <span class="nf">_get_truth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">database</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">selection</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span> + <span class="k">with</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="n">database</span><span class="p">)</span> <span class="k">as</span> <span class="n">con</span><span class="p">:</span> + <span class="n">query</span> <span class="o">=</span> <span class="s2">"SELECT * FROM truth WHERE event_no in </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span> + <span class="nb">tuple</span><span class="p">(</span><span class="n">selection</span><span class="p">)</span> + <span class="p">)</span> + <span class="n">truth</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_sql</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">con</span><span class="p">)</span> + <span class="k">return</span> <span class="n">truth</span> + + <span class="k">def</span> <span class="nf">_get_retro</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">database</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">selection</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span> + <span class="k">try</span><span class="p">:</span> + <span class="k">with</span> <span class="n">sqlite3</span><span class="o">.</span><span class="n">connect</span><span class="p">(</span><span class="n">database</span><span class="p">)</span> <span class="k">as</span> <span class="n">con</span><span class="p">:</span> + <span class="n">query</span> <span class="o">=</span> <span class="s2">"SELECT * FROM </span><span class="si">%s</span><span class="s2"> WHERE event_no in </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_retro_table_name</span><span class="p">,</span> + <span class="nb">str</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">selection</span><span class="p">)),</span> + <span class="p">)</span> + <span class="n">retro</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_sql</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">con</span><span class="p">)</span> + <span class="k">return</span> <span class="n">retro</span> + <span class="k">except</span><span class="p">:</span> <span class="c1"># noqa: E722</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> table does not exist"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">_retro_table_name</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_append_to_pipeline</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">outdir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">truth</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span> + <span class="n">retro</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span> + <span class="n">df</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> + <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">outdir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="n">pipeline_database</span> <span class="o">=</span> <span class="n">outdir</span> <span class="o">+</span> <span class="s2">"/</span><span class="si">%s</span><span class="s2">.db"</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pipeline_name</span> + <span class="n">create_table_and_save_to_sql</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="s2">"reconstruction"</span><span class="p">,</span> <span class="n">pipeline_database</span><span class="p">)</span> + <span class="n">create_table_and_save_to_sql</span><span class="p">(</span><span class="n">truth</span><span class="p">,</span> <span class="s2">"truth"</span><span class="p">,</span> <span class="n">pipeline_database</span><span class="p">)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">retro</span><span class="p">,</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">):</span> + <span class="n">create_table_and_save_to_sql</span><span class="p">(</span> + <span class="n">retro</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_retro_table_name</span><span class="p">,</span> <span class="n">pipeline_database</span> + <span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/data/sqlite/sqlite_dataconverter.html b/_modules/graphnet/data/sqlite/sqlite_dataconverter.html index a067a451a..69fa21bbe 100644 --- a/_modules/graphnet/data/sqlite/sqlite_dataconverter.html +++ b/_modules/graphnet/data/sqlite/sqlite_dataconverter.html @@ -716,7 +716,7 @@ <h1 id="modules-graphnet-data-sqlite-sqlite-dataconverter--page-root">Source cod </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/sqlite/sqlite_utilities.html b/_modules/graphnet/data/sqlite/sqlite_utilities.html index 29d4e537d..a910e007d 100644 --- a/_modules/graphnet/data/sqlite/sqlite_utilities.html +++ b/_modules/graphnet/data/sqlite/sqlite_utilities.html @@ -515,7 +515,7 @@ <h1 id="modules-graphnet-data-sqlite-sqlite-utilities--page-root">Source code fo </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/utilities/parquet_to_sqlite.html b/_modules/graphnet/data/utilities/parquet_to_sqlite.html index e645a35f4..f3944caa0 100644 --- a/_modules/graphnet/data/utilities/parquet_to_sqlite.html +++ b/_modules/graphnet/data/utilities/parquet_to_sqlite.html @@ -533,7 +533,7 @@ <h1 id="modules-graphnet-data-utilities-parquet-to-sqlite--page-root">Source cod </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/utilities/random.html b/_modules/graphnet/data/utilities/random.html index f297fc67c..eb3a7f85c 100644 --- a/_modules/graphnet/data/utilities/random.html +++ b/_modules/graphnet/data/utilities/random.html @@ -375,7 +375,7 @@ <h1 id="modules-graphnet-data-utilities-random--page-root">Source code for graph </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/data/utilities/string_selection_resolver.html b/_modules/graphnet/data/utilities/string_selection_resolver.html index bd7909f33..361902f54 100644 --- a/_modules/graphnet/data/utilities/string_selection_resolver.html +++ b/_modules/graphnet/data/utilities/string_selection_resolver.html @@ -676,7 +676,7 @@ <h1 id="modules-graphnet-data-utilities-string-selection-resolver--page-root">So </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/deployment/i3modules/graphnet_module.html b/_modules/graphnet/deployment/i3modules/graphnet_module.html new file mode 100644 index 000000000..994461c04 --- /dev/null +++ b/_modules/graphnet/deployment/i3modules/graphnet_module.html @@ -0,0 +1,817 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.deployment.i3modules.graphnet_module — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/deployment/i3modules/graphnet_module" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.deployment.i3modules.graphnet_module </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-deployment-i3modules-graphnet-module--page-root">Source code for graphnet.deployment.i3modules.graphnet_module</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Class(es) for deploying GraphNeT models in icetray as I3Modules."""</span> +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">TYPE_CHECKING</span><span class="p">,</span> <span class="n">Any</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span> + +<span class="kn">import</span> <span class="nn">dill</span> +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span> + +<span class="kn">from</span> <span class="nn">graphnet.data.extractors</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">I3FeatureExtractor</span><span class="p">,</span> + <span class="n">I3FeatureExtractorIceCubeUpgrade</span><span class="p">,</span> +<span class="p">)</span> +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span><span class="p">,</span> <span class="n">StandardModel</span> +<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.imports</span> <span class="kn">import</span> <span class="n">has_icecube_package</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">ModelConfig</span> + +<span class="k">if</span> <span class="n">has_icecube_package</span><span class="p">()</span> <span class="ow">or</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span> + <span class="kn">from</span> <span class="nn">icecube.icetray</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">I3Module</span><span class="p">,</span> + <span class="n">I3Frame</span><span class="p">,</span> + <span class="p">)</span> <span class="c1"># pyright: reportMissingImports=false</span> + <span class="kn">from</span> <span class="nn">icecube.dataclasses</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">I3Double</span><span class="p">,</span> + <span class="n">I3MapKeyVectorDouble</span><span class="p">,</span> + <span class="p">)</span> <span class="c1"># pyright: reportMissingImports=false</span> + <span class="kn">from</span> <span class="nn">icecube</span> <span class="kn">import</span> <span class="n">dataclasses</span><span class="p">,</span> <span class="n">dataio</span><span class="p">,</span> <span class="n">icetray</span> + + +<div class="viewcode-block" id="GraphNeTI3Module"> +<a class="viewcode-back" href="../../../../api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module">[docs]</a> +<span class="k">class</span> <span class="nc">GraphNeTI3Module</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Base I3 Module for GraphNeT.</span> + +<span class="sd"> Contains methods for extracting pulsemaps, producing graphs and writing to</span> +<span class="sd"> frames.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span> + <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">pulsemap_extractor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span> + <span class="n">List</span><span class="p">[</span><span class="n">I3FeatureExtractor</span><span class="p">],</span> <span class="n">I3FeatureExtractor</span> + <span class="p">],</span> + <span class="n">gcd_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""I3Module Constructor.</span> + +<span class="sd"> Arguments:</span> +<span class="sd"> graph_definition: An instance of GraphDefinition. E.g. KNNGraph.</span> +<span class="sd"> pulsemap: the pulse map on which the module functions</span> +<span class="sd"> features: the features that is used from the pulse map.</span> +<span class="sd"> E.g. [dom_x, dom_y, dom_z, charge]</span> +<span class="sd"> pulsemap_extractor: The I3FeatureExtractor used to extract the</span> +<span class="sd"> pulsemap from the I3Frames</span> +<span class="sd"> gcd_file: Path to the associated gcd-file.</span> +<span class="sd"> """</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">graph_definition</span><span class="p">,</span> <span class="n">GraphDefinition</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span> <span class="o">=</span> <span class="n">graph_definition</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemap</span> <span class="o">=</span> <span class="n">pulsemap</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_features</span> <span class="o">=</span> <span class="n">features</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">gcd_file</span><span class="p">,</span> <span class="nb">str</span><span class="p">),</span> <span class="s2">"gcd_file must be string"</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_gcd_file</span> <span class="o">=</span> <span class="n">gcd_file</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pulsemap_extractor</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span> <span class="o">=</span> <span class="n">pulsemap_extractor</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span> <span class="o">=</span> <span class="p">[</span><span class="n">pulsemap_extractor</span><span class="p">]</span> + + <span class="k">for</span> <span class="n">i3_extractor</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span><span class="p">:</span> + <span class="n">i3_extractor</span><span class="o">.</span><span class="n">set_files</span><span class="p">(</span><span class="n">i3_file</span><span class="o">=</span><span class="s2">""</span><span class="p">,</span> <span class="n">gcd_file</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_gcd_file</span><span class="p">)</span> + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Define here how the module acts on the frame.</span> + +<span class="sd"> Must return True if successful.</span> + +<span class="sd"> Return True # SUPER IMPORTANT</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="nf">_make_graph</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> <span class="c1"># py-l-i-n-t-:- -d-i-s-able=invalid-name</span> +<span class="w"> </span><span class="sd">"""Process Physics I3Frame into graph."""</span> + <span class="c1"># Extract features</span> + <span class="n">node_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_extract_feature_array_from_frame</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span> + <span class="c1"># Prepare graph data</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">node_features</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> + <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span><span class="p">(</span> + <span class="n">node_features</span><span class="o">=</span><span class="n">node_features</span><span class="p">,</span> + <span class="n">node_feature_names</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">Batch</span><span class="o">.</span><span class="n">from_data_list</span><span class="p">([</span><span class="n">data</span><span class="p">])</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="kc">None</span> + + <span class="k">def</span> <span class="nf">_extract_feature_array_from_frame</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Apply the I3FeatureExtractors to the I3Frame.</span> + +<span class="sd"> Arguments:</span> +<span class="sd"> frame: Physics I3Frame (PFrame)</span> + +<span class="sd"> Returns:</span> +<span class="sd"> array with pulsemap</span> +<span class="sd"> """</span> + <span class="n">features</span> <span class="o">=</span> <span class="kc">None</span> + <span class="k">for</span> <span class="n">i3extractor</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span><span class="p">:</span> + <span class="n">feature_dict</span> <span class="o">=</span> <span class="n">i3extractor</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span> + <span class="n">features_pulsemap</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span> + <span class="p">[</span><span class="n">feature_dict</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_features</span><span class="p">]</span> + <span class="p">)</span><span class="o">.</span><span class="n">T</span> + <span class="k">if</span> <span class="n">features</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">features</span> <span class="o">=</span> <span class="n">features_pulsemap</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> + <span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">features_pulsemap</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">features</span> + + <span class="k">def</span> <span class="nf">_add_to_frame</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-></span> <span class="n">I3Frame</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Add every field in data to I3Frame.</span> + +<span class="sd"> Arguments:</span> +<span class="sd"> frame: I3Frame (physics)</span> +<span class="sd"> data: Dictionary containing content that will be written to frame.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> frame: Same I3Frame as input, but with the new entries</span> +<span class="sd"> """</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span> + <span class="n">data</span><span class="p">,</span> <span class="nb">dict</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"data must be of type dict. Got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">data</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span> + <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">data</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="p">:</span> + <span class="n">frame</span><span class="o">.</span><span class="n">Put</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="n">key</span><span class="p">])</span> + <span class="k">return</span> <span class="n">frame</span></div> + + + +<div class="viewcode-block" id="I3InferenceModule"> +<a class="viewcode-back" href="../../../../api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule">[docs]</a> +<span class="k">class</span> <span class="nc">I3InferenceModule</span><span class="p">(</span><span class="n">GraphNeTI3Module</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""General class for inference on i3 frames."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">pulsemap_extractor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span> + <span class="n">List</span><span class="p">[</span><span class="n">I3FeatureExtractor</span><span class="p">],</span> <span class="n">I3FeatureExtractor</span> + <span class="p">],</span> + <span class="n">model_config</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span> + <span class="n">state_dict</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">gcd_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""General class for inference on I3Frames (physics).</span> + +<span class="sd"> Arguments:</span> +<span class="sd"> pulsemap: the pulsmap that the model is expecting as input.</span> +<span class="sd"> features: the features of the pulsemap that the model is expecting.</span> +<span class="sd"> pulsemap_extractor: The extractor used to extract the pulsemap.</span> +<span class="sd"> model_config: The ModelConfig (or path to it) that summarizes the</span> +<span class="sd"> model used for inference.</span> +<span class="sd"> state_dict: Path to state_dict containing the learned weights.</span> +<span class="sd"> model_name: The name used for the model. Will help define the</span> +<span class="sd"> named entry in the I3Frame. E.g. "dynedge".</span> +<span class="sd"> gcd_file: path to associated gcd file.</span> +<span class="sd"> prediction_columns: column names for the predictions of the model.</span> +<span class="sd"> Will help define the named entry in the I3Frame.</span> +<span class="sd"> E.g. ['energy_reco']. Optional.</span> +<span class="sd"> """</span> + <span class="c1"># Construct model & load weights</span> + <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">trust</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span> + + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span> + <span class="n">pulsemap</span><span class="o">=</span><span class="n">pulsemap</span><span class="p">,</span> + <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span> + <span class="n">pulsemap_extractor</span><span class="o">=</span><span class="n">pulsemap_extractor</span><span class="p">,</span> + <span class="n">gcd_file</span><span class="o">=</span><span class="n">gcd_file</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">_graph_definition</span><span class="p">,</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">"cpu"</span><span class="p">)</span> + <span class="k">if</span> <span class="n">prediction_columns</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span> <span class="o">=</span> <span class="p">[</span><span class="n">prediction_columns</span><span class="p">]</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span> <span class="o">=</span> <span class="n">prediction_columns</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">prediction_labels</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">model_name</span> <span class="o">=</span> <span class="n">model_name</span> + + <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Write predictions from model to frame."""</span> + <span class="c1"># inference</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_graph</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span> + <span class="k">if</span> <span class="n">graph</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">predictions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span> + <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">],</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="p">)</span> + <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="p">))</span> + <span class="c1"># Check dimensions of predictions and prediction columns</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span> + <span class="n">dim</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">dim</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> + <span class="k">assert</span> <span class="n">dim</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"""predictions have shape </span><span class="si">{</span><span class="n">dim</span><span class="si">}</span><span class="s2"> but </span><span class="se">\n</span> +<span class="s2"> prediction columns have [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="si">}</span><span class="s2">]"""</span> + + <span class="c1"># Build Dictionary of predictions</span> + <span class="n">data</span> <span class="o">=</span> <span class="p">{}</span> + <span class="k">assert</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> + <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">dim</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)):</span> + <span class="k">try</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">[:,</span> <span class="n">i</span><span class="p">])</span> <span class="o">==</span> <span class="mi">1</span> + <span class="n">data</span><span class="p">[</span> + <span class="bp">self</span><span class="o">.</span><span class="n">model_name</span> <span class="o">+</span> <span class="s2">"_"</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> + <span class="p">]</span> <span class="o">=</span> <span class="n">I3Double</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">predictions</span><span class="p">[:,</span> <span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span> + <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span> + <span class="n">data</span><span class="p">[</span> + <span class="bp">self</span><span class="o">.</span><span class="n">model_name</span> <span class="o">+</span> <span class="s2">"_"</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_columns</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> + <span class="p">]</span> <span class="o">=</span> <span class="n">I3Double</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> + + <span class="c1"># Submission methods</span> + <span class="n">frame</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_to_frame</span><span class="p">(</span><span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">)</span> + <span class="k">return</span> <span class="kc">True</span> + + <span class="k">def</span> <span class="nf">_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span> + <span class="c1"># Perform inference</span> + <span class="n">task_predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="nb">len</span><span class="p">(</span><span class="n">task_predictions</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"""This method assumes a single task. </span><span class="se">\n</span> +<span class="s2"> Got </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">task_predictions</span><span class="p">)</span><span class="si">}</span><span class="s2"> tasks."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></div> + + + +<div class="viewcode-block" id="I3PulseCleanerModule"> +<a class="viewcode-back" href="../../../../api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule">[docs]</a> +<span class="k">class</span> <span class="nc">I3PulseCleanerModule</span><span class="p">(</span><span class="n">I3InferenceModule</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""A specialized module for pulse cleaning.</span> + +<span class="sd"> It is assumed that the model provided has been trained for this.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">pulsemap</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">pulsemap_extractor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span> + <span class="n">List</span><span class="p">[</span><span class="n">I3FeatureExtractor</span><span class="p">],</span> <span class="n">I3FeatureExtractor</span> + <span class="p">],</span> + <span class="n">model_config</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">state_dict</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">gcd_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span><span class="p">,</span> + <span class="n">discard_empty_events</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""General class for inference on I3Frames (physics).</span> + +<span class="sd"> Arguments:</span> +<span class="sd"> pulsemap: the pulsmap that the model is expecting as input</span> +<span class="sd"> (the one that is being cleaned).</span> +<span class="sd"> features: the features of the pulsemap that the model is expecting.</span> +<span class="sd"> pulsemap_extractor: The extractor used to extract the pulsemap.</span> +<span class="sd"> model_config: The ModelConfig (or path to it) that summarizes the</span> +<span class="sd"> model used for inference.</span> +<span class="sd"> state_dict: Path to state_dict containing the learned weights.</span> +<span class="sd"> model_name: The name used for the model. Will help define the named</span> +<span class="sd"> entry in the I3Frame. E.g. "dynedge".</span> +<span class="sd"> gcd_file: path to associated gcd file.</span> +<span class="sd"> threshold: the threshold for being considered a positive case.</span> +<span class="sd"> E.g., predictions >= threshold will be considered</span> +<span class="sd"> to be signal, all else noise.</span> +<span class="sd"> discard_empty_events: When true, this flag will eliminate events</span> +<span class="sd"> whose cleaned pulse series are empty. Can be used</span> +<span class="sd"> to speed up processing especially for noise</span> +<span class="sd"> simulation, since it will not do any writing or</span> +<span class="sd"> further calculations.</span> +<span class="sd"> prediction_columns: column names for the predictions of the model.</span> +<span class="sd"> Will help define the named entry in the I3Frame.</span> +<span class="sd"> E.g. ['energy_reco']. Optional.</span> +<span class="sd"> """</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span> + <span class="n">pulsemap</span><span class="o">=</span><span class="n">pulsemap</span><span class="p">,</span> + <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span> + <span class="n">pulsemap_extractor</span><span class="o">=</span><span class="n">pulsemap_extractor</span><span class="p">,</span> + <span class="n">model_config</span><span class="o">=</span><span class="n">model_config</span><span class="p">,</span> + <span class="n">state_dict</span><span class="o">=</span><span class="n">state_dict</span><span class="p">,</span> + <span class="n">model_name</span><span class="o">=</span><span class="n">model_name</span><span class="p">,</span> + <span class="n">prediction_columns</span><span class="o">=</span><span class="n">prediction_columns</span><span class="p">,</span> + <span class="n">gcd_file</span><span class="o">=</span><span class="n">gcd_file</span><span class="p">,</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span> <span class="o">=</span> <span class="n">threshold</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_predictions_key</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">pulsemap</span><span class="si">}</span><span class="s2">_</span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2">_Predictions"</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">pulsemap</span><span class="si">}</span><span class="s2">_</span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2">_Pulses"</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_discard_empty_events</span> <span class="o">=</span> <span class="n">discard_empty_events</span> + + <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Add a cleaned pulsemap to frame."""</span> + <span class="c1"># inference</span> + <span class="n">gcd_file</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_gcd_file</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_graph</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span> + <span class="k">if</span> <span class="n">graph</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># If there is no pulses to clean</span> + <span class="k">return</span> <span class="kc">False</span> + <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_discard_empty_events</span><span class="p">:</span> + <span class="k">if</span> <span class="nb">sum</span><span class="p">(</span><span class="n">predictions</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> + <span class="k">return</span> <span class="kc">False</span> + + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> + <span class="n">predictions</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + + <span class="k">assert</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> + + <span class="c1"># Build Dictionary of predictions</span> + <span class="n">data</span> <span class="o">=</span> <span class="p">{}</span> + + <span class="n">predictions_map</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_prediction_map</span><span class="p">(</span> + <span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">predictions</span><span class="o">=</span><span class="n">predictions</span> + <span class="p">)</span> + + <span class="c1"># Adds the raw predictions to dictionary</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_predictions_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">data</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_predictions_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">predictions_map</span> + + <span class="c1"># Create a pulse map mask, indicating the pulses that are over</span> + <span class="c1"># threshold (e.g. identified as signal) and therefore should be kept</span> + <span class="c1"># Using a lambda function to evaluate which pulses to keep by</span> + <span class="c1"># checking the prediction for each pulse</span> + <span class="c1"># (Adds the actual pulsemap to dictionary)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">data</span><span class="p">[</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span> + <span class="p">]</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMapMask</span><span class="p">(</span> + <span class="n">frame</span><span class="p">,</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemap</span><span class="p">,</span> + <span class="k">lambda</span> <span class="n">om_key</span><span class="p">,</span> <span class="n">index</span><span class="p">,</span> <span class="n">pulse</span><span class="p">:</span> <span class="n">predictions_map</span><span class="p">[</span><span class="n">om_key</span><span class="p">][</span><span class="n">index</span><span class="p">]</span> + <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Submit predictions and general pulsemap</span> + <span class="n">frame</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_to_frame</span><span class="p">(</span><span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">)</span> + <span class="n">data</span> <span class="o">=</span> <span class="p">{}</span> + <span class="c1"># Adds an additional pulsemap for each DOM type</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_i3_extractors</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">I3FeatureExtractorIceCubeUpgrade</span> + <span class="p">):</span> + <span class="n">mDOMMap</span><span class="p">,</span> <span class="n">DEggMap</span><span class="p">,</span> <span class="n">IceCubeMap</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_split_pulsemap_in_dom_types</span><span class="p">(</span> + <span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">gcd_file</span><span class="o">=</span><span class="n">gcd_file</span> + <span class="p">)</span> + + <span class="k">if</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_mDOMs_Only"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">data</span><span class="p">[</span> + <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_mDOMs_Only"</span> + <span class="p">]</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="p">(</span><span class="n">mDOMMap</span><span class="p">)</span> + + <span class="k">if</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_dEggs_Only"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">data</span><span class="p">[</span> + <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_dEggs_Only"</span> + <span class="p">]</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="p">(</span><span class="n">DEggMap</span><span class="p">)</span> + + <span class="k">if</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_pDOMs_Only"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">data</span><span class="p">[</span> + <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span><span class="si">}</span><span class="s2">_pDOMs_Only"</span> + <span class="p">]</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="p">(</span><span class="n">IceCubeMap</span><span class="p">)</span> + + <span class="c1"># Submits the additional pulsemaps to the frame</span> + <span class="n">frame</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_to_frame</span><span class="p">(</span><span class="n">frame</span><span class="o">=</span><span class="n">frame</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">data</span><span class="p">)</span> + + <span class="k">return</span> <span class="kc">True</span> + + <span class="k">def</span> <span class="nf">_split_pulsemap_in_dom_types</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">,</span> <span class="n">gcd_file</span><span class="p">:</span> <span class="n">Any</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Will split the cleaned pulsemap into multiple pulsemaps.</span> + +<span class="sd"> Arguments:</span> +<span class="sd"> frame: I3Frame (physics)</span> +<span class="sd"> gcd_file: path to associated gcd file</span> + +<span class="sd"> Returns:</span> +<span class="sd"> mDOMMap, DeGGMap, IceCubeMap</span> +<span class="sd"> """</span> + <span class="n">g</span> <span class="o">=</span> <span class="n">dataio</span><span class="o">.</span><span class="n">I3File</span><span class="p">(</span><span class="n">gcd_file</span><span class="p">)</span> + <span class="n">gFrame</span> <span class="o">=</span> <span class="n">g</span><span class="o">.</span><span class="n">pop_frame</span><span class="p">()</span> + <span class="k">while</span> <span class="s2">"I3Geometry"</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">gFrame</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">gFrame</span> <span class="o">=</span> <span class="n">g</span><span class="o">.</span><span class="n">pop_frame</span><span class="p">()</span> + <span class="n">omGeoMap</span> <span class="o">=</span> <span class="n">gFrame</span><span class="p">[</span><span class="s2">"I3Geometry"</span><span class="p">]</span><span class="o">.</span><span class="n">omgeo</span> + + <span class="n">mDOMMap</span><span class="p">,</span> <span class="n">DEggMap</span><span class="p">,</span> <span class="n">IceCubeMap</span> <span class="o">=</span> <span class="p">{},</span> <span class="p">{},</span> <span class="p">{}</span> + <span class="n">pulses</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="o">.</span><span class="n">from_frame</span><span class="p">(</span> + <span class="n">frame</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_total_pulsemap_name</span> + <span class="p">)</span> + <span class="k">for</span> <span class="n">P</span> <span class="ow">in</span> <span class="n">pulses</span><span class="p">:</span> + <span class="n">om</span> <span class="o">=</span> <span class="n">omGeoMap</span><span class="p">[</span><span class="n">P</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> + <span class="k">if</span> <span class="n">om</span><span class="o">.</span><span class="n">omtype</span> <span class="o">==</span> <span class="mi">130</span><span class="p">:</span> <span class="c1"># "mDOM"</span> + <span class="n">mDOMMap</span><span class="p">[</span><span class="n">P</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">P</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> + <span class="k">elif</span> <span class="n">om</span><span class="o">.</span><span class="n">omtype</span> <span class="o">==</span> <span class="mi">120</span><span class="p">:</span> <span class="c1"># "DEgg"</span> + <span class="n">DEggMap</span><span class="p">[</span><span class="n">P</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">P</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> + <span class="k">elif</span> <span class="n">om</span><span class="o">.</span><span class="n">omtype</span> <span class="o">==</span> <span class="mi">20</span><span class="p">:</span> <span class="c1"># "IceCube / pDOM"</span> + <span class="n">IceCubeMap</span><span class="p">[</span><span class="n">P</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="o">=</span> <span class="n">P</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> + <span class="k">return</span> <span class="n">mDOMMap</span><span class="p">,</span> <span class="n">DEggMap</span><span class="p">,</span> <span class="n">IceCubeMap</span> + + <span class="k">def</span> <span class="nf">_construct_prediction_map</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">frame</span><span class="p">:</span> <span class="n">I3Frame</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">I3MapKeyVectorDouble</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Make a pulsemap from predictions (for all OM types).</span> + +<span class="sd"> Arguments:</span> +<span class="sd"> frame: I3Frame (physics)</span> +<span class="sd"> predictions: predictions from GNN</span> + +<span class="sd"> Returns:</span> +<span class="sd"> predictions_map: a pulsemap from predictions</span> +<span class="sd"> """</span> + <span class="n">pulsemap</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3RecoPulseSeriesMap</span><span class="o">.</span><span class="n">from_frame</span><span class="p">(</span> + <span class="n">frame</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pulsemap</span> + <span class="p">)</span> + + <span class="n">idx</span> <span class="o">=</span> <span class="mi">0</span> + <span class="n">predictions</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> + <span class="n">predictions_map</span> <span class="o">=</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">I3MapKeyVectorDouble</span><span class="p">()</span> + <span class="k">for</span> <span class="n">om_key</span><span class="p">,</span> <span class="n">pulses</span> <span class="ow">in</span> <span class="n">pulsemap</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> + <span class="n">num_pulses</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">pulses</span><span class="p">)</span> + <span class="n">predictions_map</span><span class="p">[</span><span class="n">om_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">predictions</span><span class="p">[</span> + <span class="n">idx</span> <span class="p">:</span> <span class="n">idx</span> <span class="o">+</span> <span class="n">num_pulses</span> + <span class="p">]</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + <span class="n">idx</span> <span class="o">+=</span> <span class="n">num_pulses</span> + + <span class="c1"># Checks</span> + <span class="k">assert</span> <span class="n">idx</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span> + <span class="n">predictions</span> + <span class="p">),</span> <span class="s2">"""Not all predictions were mapped to pulses,</span><span class="se">\n</span> +<span class="s2"> validation of predictions have failed."""</span> + + <span class="k">assert</span> <span class="p">(</span> + <span class="n">pulsemap</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="o">==</span> <span class="n">predictions_map</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> + <span class="p">),</span> <span class="s2">"""Input pulse map and predictions map do </span><span class="se">\n</span> +<span class="s2"> not contain exactly the same OMs"""</span> + <span class="k">return</span> <span class="n">predictions_map</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/coarsening.html b/_modules/graphnet/models/coarsening.html new file mode 100644 index 000000000..6a260f076 --- /dev/null +++ b/_modules/graphnet/models/coarsening.html @@ -0,0 +1,711 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.coarsening — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/coarsening" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.coarsening </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-coarsening--page-root">Source code for graphnet.models.coarsening</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Class(es) for coarsening operations (i.e., clustering, or local pooling)."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span> +<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span> +<span class="kn">from</span> <span class="nn">sklearn.cluster</span> <span class="kn">import</span> <span class="n">DBSCAN</span> + +<span class="c1"># from torch_geometric.utils import unbatch_edge_index</span> +<span class="kn">from</span> <span class="nn">graphnet.models.components.pool</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">group_by</span><span class="p">,</span> + <span class="n">avg_pool</span><span class="p">,</span> + <span class="n">max_pool</span><span class="p">,</span> + <span class="n">min_pool</span><span class="p">,</span> + <span class="n">sum_pool</span><span class="p">,</span> + <span class="n">avg_pool_x</span><span class="p">,</span> + <span class="n">max_pool_x</span><span class="p">,</span> + <span class="n">min_pool_x</span><span class="p">,</span> + <span class="n">sum_pool_x</span><span class="p">,</span> + <span class="n">std_pool_x</span><span class="p">,</span> +<span class="p">)</span> +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> + +<span class="c1"># Utility method(s)</span> +<span class="kn">from</span> <span class="nn">torch_geometric.utils</span> <span class="kn">import</span> <span class="n">degree</span> + +<span class="c1"># NOTE: From [https://github.com/pyg-team/pytorch_geometric/pull/4903]</span> +<span class="c1"># TODO: Remove once bumping to torch_geometric>=2.1.0</span> +<span class="c1"># See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md]</span> + + +<div class="viewcode-block" id="unbatch_edge_index"> +<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.unbatch_edge_index">[docs]</a> +<span class="k">def</span> <span class="nf">unbatch_edge_index</span><span class="p">(</span><span class="n">edge_index</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span> + <span class="c1"># noqa: D401</span> +<span class="w"> </span><span class="sa">r</span><span class="sd">"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.</span> + +<span class="sd"> Args:</span> +<span class="sd"> edge_index (Tensor): The edge_index tensor. Must be ordered.</span> +<span class="sd"> batch (LongTensor): The batch vector</span> +<span class="sd"> :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each</span> +<span class="sd"> node to a specific example. Must be ordered.</span> +<span class="sd"> :rtype: :class:`List[Tensor]`</span> +<span class="sd"> """</span> + <span class="n">deg</span> <span class="o">=</span> <span class="n">degree</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span> + <span class="n">ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">deg</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">deg</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + + <span class="n">edge_batch</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">edge_index</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> + <span class="n">edge_index</span> <span class="o">=</span> <span class="n">edge_index</span> <span class="o">-</span> <span class="n">ptr</span><span class="p">[</span><span class="n">edge_batch</span><span class="p">]</span> + <span class="n">sizes</span> <span class="o">=</span> <span class="n">degree</span><span class="p">(</span><span class="n">edge_batch</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + <span class="k">return</span> <span class="n">edge_index</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="Coarsening"> +<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening">[docs]</a> +<span class="k">class</span> <span class="nc">Coarsening</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base class for coarsening operations."""</span> + + <span class="c1"># Class variables</span> + <span class="n">reduce_options</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"avg"</span><span class="p">:</span> <span class="p">(</span><span class="n">avg_pool</span><span class="p">,</span> <span class="n">avg_pool_x</span><span class="p">),</span> + <span class="s2">"min"</span><span class="p">:</span> <span class="p">(</span><span class="n">min_pool</span><span class="p">,</span> <span class="n">min_pool_x</span><span class="p">),</span> + <span class="s2">"max"</span><span class="p">:</span> <span class="p">(</span><span class="n">max_pool</span><span class="p">,</span> <span class="n">max_pool_x</span><span class="p">),</span> + <span class="s2">"sum"</span><span class="p">:</span> <span class="p">(</span><span class="n">sum_pool</span><span class="p">,</span> <span class="n">sum_pool_x</span><span class="p">),</span> + <span class="p">}</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"avg"</span><span class="p">,</span> + <span class="n">transfer_attributes</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `Coarsening`."""</span> + <span class="k">assert</span> <span class="n">reduce</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_options</span> + + <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_method</span><span class="p">,</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_attribute_reduce_method</span><span class="p">,</span> + <span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduce_options</span><span class="p">[</span><span class="n">reduce</span><span class="p">]</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_do_transfer_attributes</span> <span class="o">=</span> <span class="n">transfer_attributes</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_perform_clustering</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-></span> <span class="n">LongTensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Cluster nodes in `data` by assigning a cluster index to each."""</span> + + <span class="k">def</span> <span class="nf">_additional_features</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Batch</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform additional poolings of feature tensor `x` on `data`.</span> + +<span class="sd"> By default the nominal `pooling_method` is used for features as well.</span> +<span class="sd"> This method can be overwritten for bespoke coarsening operations.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="nf">_transfer_attributes</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">original_data</span><span class="p">:</span> <span class="n">Batch</span><span class="p">,</span> <span class="n">pooled_data</span><span class="p">:</span> <span class="n">Batch</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Batch</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Transfer attributes on `original_data` to `pooled_data`."""</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_do_transfer_attributes</span><span class="p">:</span> + <span class="k">return</span> <span class="n">pooled_data</span> + + <span class="n">attributes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">original_data</span><span class="o">.</span><span class="n">_store</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> + <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">original_data</span><span class="o">.</span><span class="n">batch</span> + <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">attr</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">attributes</span><span class="p">):</span> + <span class="k">if</span> <span class="n">attr</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">pooled_data</span><span class="o">.</span><span class="n">_store</span><span class="p">:</span> + <span class="n">values</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">original_data</span><span class="p">,</span> <span class="n">attr</span><span class="p">)</span> + + <span class="n">attr_is_node_level_tensor</span> <span class="o">=</span> <span class="kc">False</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">values</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span> + <span class="k">if</span> <span class="n">batch</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">attr_is_node_level_tensor</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">values</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">values</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span> + <span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">attr_is_node_level_tensor</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">values</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="n">original_data</span><span class="o">.</span><span class="n">batch</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> + <span class="p">)</span> + + <span class="k">if</span> <span class="n">attr_is_node_level_tensor</span><span class="p">:</span> + <span class="n">values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_attribute_reduce_method</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">,</span> + <span class="n">values</span><span class="p">,</span> + <span class="n">batch</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">values</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span> + <span class="p">)[</span><span class="mi">0</span><span class="p">]</span> + + <span class="nb">setattr</span><span class="p">(</span><span class="n">pooled_data</span><span class="p">,</span> <span class="n">attr</span><span class="p">,</span> <span class="n">values</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">pooled_data</span> + +<div class="viewcode-block" id="Coarsening.forward"> +<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Perform coarsening operation."""</span> + <span class="c1"># Get tensor of cluster indices for each node.</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_perform_clustering</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + + <span class="c1"># Check whether a graph has already been built. Otherwise, set a dummy</span> + <span class="c1"># connectivity, as this is required by pooling functions.</span> + <span class="n">edge_index</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span> + <span class="k">if</span> <span class="n">edge_index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span> + + <span class="c1"># Pool `data` object, including `x`, `batch`. and `edge_index`.</span> + <span class="n">pooled_data</span><span class="p">:</span> <span class="n">Batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reduce_method</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span> + + <span class="c1"># Optionally overwrite feature tensor</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_additional_features</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span> + <span class="k">if</span> <span class="n">x</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">pooled_data</span><span class="o">.</span><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">(</span> + <span class="n">pooled_data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> + <span class="n">x</span><span class="p">,</span> + <span class="p">),</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Reset `edge_index` if necessary.</span> + <span class="k">if</span> <span class="n">edge_index</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">edge_index</span> + <span class="n">pooled_data</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">edge_index</span> + + <span class="c1"># Transfer attributes on `data`, pooling as required.</span> + <span class="n">pooled_data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transfer_attributes</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">pooled_data</span><span class="p">)</span> + + <span class="c1"># Reconstruct Batch Attributes</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">):</span> <span class="c1"># if a Batch object</span> + <span class="n">pooled_data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reconstruct_batch</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">pooled_data</span><span class="p">)</span> + <span class="k">return</span> <span class="n">pooled_data</span></div> + + + <span class="k">def</span> <span class="nf">_reconstruct_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">original</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">pooled</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> + <span class="n">pooled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_slice_dict</span><span class="p">(</span><span class="n">original</span><span class="p">,</span> <span class="n">pooled</span><span class="p">)</span> + <span class="n">pooled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_inc_dict</span><span class="p">(</span><span class="n">original</span><span class="p">,</span> <span class="n">pooled</span><span class="p">)</span> + <span class="k">return</span> <span class="n">pooled</span> + + <span class="k">def</span> <span class="nf">_add_slice_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">original</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">pooled</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> + <span class="c1"># Copy original slice_dict and count nodes in each graph in pooled batch</span> + <span class="n">slice_dict</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">original</span><span class="o">.</span><span class="n">_slice_dict</span><span class="p">)</span> + <span class="n">_</span><span class="p">,</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique_consecutive</span><span class="p">(</span><span class="n">pooled</span><span class="o">.</span><span class="n">batch</span><span class="p">,</span> <span class="n">return_counts</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="c1"># Reconstruct the entry in slice_dict for pulsemaps - only these are affected by pooling</span> + <span class="n">pulsemap_slice</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> + <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">counts</span><span class="p">)):</span> + <span class="n">pulsemap_slice</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pulsemap_slice</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">counts</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> + + <span class="c1"># Identifies pulsemap entries in slice_dict and set them to pulsemap_slice</span> + <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="n">slice_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="k">if</span> <span class="p">(</span><span class="n">original</span><span class="o">.</span><span class="n">_num_graphs</span><span class="p">)</span> <span class="o">==</span> <span class="n">slice_dict</span><span class="p">[</span><span class="n">field</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span> + <span class="k">pass</span> <span class="c1"># not pulsemap, so skip</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">slice_dict</span><span class="p">[</span><span class="n">field</span><span class="p">]</span> <span class="o">=</span> <span class="n">pulsemap_slice</span> + <span class="n">pooled</span><span class="o">.</span><span class="n">_slice_dict</span> <span class="o">=</span> <span class="n">slice_dict</span> + <span class="k">return</span> <span class="n">pooled</span> + + <span class="k">def</span> <span class="nf">_add_inc_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">original</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">pooled</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> + <span class="c1"># not changed by coarsening</span> + <span class="n">pooled</span><span class="o">.</span><span class="n">_inc_dict</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">original</span><span class="o">.</span><span class="n">_inc_dict</span><span class="p">)</span> + <span class="k">return</span> <span class="n">pooled</span></div> + + + +<div class="viewcode-block" id="AttributeCoarsening"> +<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.AttributeCoarsening">[docs]</a> +<span class="k">class</span> <span class="nc">AttributeCoarsening</span><span class="p">(</span><span class="n">Coarsening</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Coarsen pulses based on specified attributes."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">attributes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"avg"</span><span class="p">,</span> + <span class="n">transfer_attributes</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `SimpleCoarsening`."""</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_attributes</span> <span class="o">=</span> <span class="n">attributes</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">transfer_attributes</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_perform_clustering</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-></span> <span class="n">LongTensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Cluster nodes in `data` by assigning a cluster index to each."""</span> + <span class="n">dom_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_attributes</span><span class="p">)</span> + <span class="k">return</span> <span class="n">dom_index</span></div> + + + +<div class="viewcode-block" id="DOMCoarsening"> +<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.DOMCoarsening">[docs]</a> +<span class="k">class</span> <span class="nc">DOMCoarsening</span><span class="p">(</span><span class="n">Coarsening</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Coarsen pulses to DOM-level."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"avg"</span><span class="p">,</span> + <span class="n">transfer_attributes</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> + <span class="n">keys</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Cluster pulses on the same DOM."""</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">transfer_attributes</span><span class="p">)</span> + <span class="k">if</span> <span class="n">keys</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span> <span class="o">=</span> <span class="p">[</span> + <span class="s2">"dom_x"</span><span class="p">,</span> + <span class="s2">"dom_y"</span><span class="p">,</span> + <span class="s2">"dom_z"</span><span class="p">,</span> + <span class="s2">"rde"</span><span class="p">,</span> + <span class="s2">"pmt_area"</span><span class="p">,</span> + <span class="p">]</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span> <span class="o">=</span> <span class="n">keys</span> + + <span class="k">def</span> <span class="nf">_perform_clustering</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-></span> <span class="n">LongTensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Cluster nodes in `data` by assigning a cluster index to each."""</span> + <span class="n">dom_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span><span class="p">)</span> + <span class="k">return</span> <span class="n">dom_index</span></div> + + + +<div class="viewcode-block" id="CustomDOMCoarsening"> +<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.CustomDOMCoarsening">[docs]</a> +<span class="k">class</span> <span class="nc">CustomDOMCoarsening</span><span class="p">(</span><span class="n">DOMCoarsening</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Coarsen pulses to DOM-level with additional attributes."""</span> + + <span class="k">def</span> <span class="nf">_additional_features</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform Additional poolings of feature tensor `x` on `data`."""</span> + <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> + + <span class="n">features</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">features</span> + <span class="k">if</span> <span class="n">batch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">features</span> <span class="o">=</span> <span class="p">[</span><span class="n">feats</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">feats</span> <span class="ow">in</span> <span class="n">features</span><span class="p">]</span> + + <span class="n">ix_time</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="s2">"dom_time"</span><span class="p">)</span> + <span class="n">ix_charge</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="s2">"charge"</span><span class="p">)</span> + + <span class="n">time</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">ix_time</span><span class="p">]</span> + <span class="n">charge</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">ix_charge</span><span class="p">]</span> + + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span> + <span class="p">(</span> + <span class="n">min_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">time</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> + <span class="n">max_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">time</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> + <span class="n">std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">time</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> + <span class="n">min_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">charge</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> + <span class="n">max_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">charge</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> + <span class="n">std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">charge</span><span class="p">,</span> <span class="n">batch</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> + <span class="n">sum_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">charge</span><span class="p">),</span> <span class="n">batch</span><span class="p">)[</span> + <span class="mi">0</span> + <span class="p">],</span> <span class="c1"># Num. nodes (pulses) per cluster (DOM)</span> + <span class="p">),</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">x</span></div> + + + +<div class="viewcode-block" id="DOMAndTimeWindowCoarsening"> +<a class="viewcode-back" href="../../../api/graphnet.models.coarsening.html#graphnet.models.coarsening.DOMAndTimeWindowCoarsening">[docs]</a> +<span class="k">class</span> <span class="nc">DOMAndTimeWindowCoarsening</span><span class="p">(</span><span class="n">Coarsening</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Coarsen pulses to DOM-level, with additional time-window clustering."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">time_window</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> + <span class="n">reduce</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"avg"</span><span class="p">,</span> + <span class="n">transfer_attributes</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> + <span class="n">keys</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span> + <span class="s2">"dom_x"</span><span class="p">,</span> + <span class="s2">"dom_y"</span><span class="p">,</span> + <span class="s2">"dom_z"</span><span class="p">,</span> + <span class="s2">"rde"</span><span class="p">,</span> + <span class="s2">"pmt_area"</span><span class="p">,</span> + <span class="p">],</span> + <span class="n">time_key</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"dom_time"</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Cluster pulses on the same DOM within `time_window`."""</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">transfer_attributes</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_time_window</span> <span class="o">=</span> <span class="n">time_window</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_cluster_method</span> <span class="o">=</span> <span class="n">DBSCAN</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_time_window</span><span class="p">,</span> <span class="n">min_samples</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span> <span class="o">=</span> <span class="n">keys</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_time_key</span> <span class="o">=</span> <span class="n">time_key</span> + + <span class="k">def</span> <span class="nf">_perform_clustering</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">])</span> <span class="o">-></span> <span class="n">LongTensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Cluster nodes in `data` by assigning a cluster index to each."""</span> + <span class="n">dom_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_keys</span><span class="p">)</span> + <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">features</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">features</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">features</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">features</span> + + <span class="n">ix_time</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_time_key</span><span class="p">)</span> + <span class="n">hit_times</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">ix_time</span><span class="p">]</span> + + <span class="c1"># Scale up dom_index to make sure clusters are well separated</span> + <span class="n">times_and_domids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">hit_times</span><span class="p">,</span> + <span class="n">dom_index</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">_time_window</span> <span class="o">*</span> <span class="mi">10</span><span class="p">,</span> + <span class="p">]</span> + <span class="p">)</span><span class="o">.</span><span class="n">T</span> + <span class="n">clusters</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_cluster_method</span><span class="o">.</span><span class="n">fit_predict</span><span class="p">(</span><span class="n">times_and_domids</span><span class="o">.</span><span class="n">cpu</span><span class="p">()),</span> + <span class="n">device</span><span class="o">=</span><span class="n">hit_times</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">clusters</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/components/layers.html b/_modules/graphnet/models/components/layers.html new file mode 100644 index 000000000..555ac857b --- /dev/null +++ b/_modules/graphnet/models/components/layers.html @@ -0,0 +1,579 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.components.layers — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/components/layers" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.components.layers </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-components-layers--page-root">Source code for graphnet.models.components.layers</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Class(es) implementing layers to be used in `graphnet` models."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch.functional</span> <span class="kn">import</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn</span> <span class="kn">import</span> <span class="n">EdgeConv</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn.pool</span> <span class="kn">import</span> <span class="n">knn_graph</span> +<span class="kn">from</span> <span class="nn">torch_geometric.typing</span> <span class="kn">import</span> <span class="n">Adj</span><span class="p">,</span> <span class="n">PairTensor</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn.conv</span> <span class="kn">import</span> <span class="n">MessagePassing</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn.inits</span> <span class="kn">import</span> <span class="n">reset</span> +<span class="kn">from</span> <span class="nn">torch.nn.modules</span> <span class="kn">import</span> <span class="n">TransformerEncoder</span><span class="p">,</span> <span class="n">TransformerEncoderLayer</span> +<span class="kn">from</span> <span class="nn">torch.nn.modules.normalization</span> <span class="kn">import</span> <span class="n">LayerNorm</span> +<span class="kn">from</span> <span class="nn">torch_geometric.utils</span> <span class="kn">import</span> <span class="n">to_dense_batch</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">LightningModule</span> + + +<div class="viewcode-block" id="DynEdgeConv"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv">[docs]</a> +<span class="k">class</span> <span class="nc">DynEdgeConv</span><span class="p">(</span><span class="n">EdgeConv</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Dynamical edge convolution layer."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">nn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> + <span class="n">aggr</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"max"</span><span class="p">,</span> + <span class="n">nb_neighbors</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> + <span class="n">features_subset</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">slice</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `DynEdgeConv`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> nn: The MLP/torch.Module to be used within the `EdgeConv`.</span> +<span class="sd"> aggr: Aggregation method to be used with `EdgeConv`.</span> +<span class="sd"> nb_neighbors: Number of neighbours to be clustered after the</span> +<span class="sd"> `EdgeConv` operation.</span> +<span class="sd"> features_subset: Subset of features in `Data.x` that should be used</span> +<span class="sd"> when dynamically performing the new graph clustering after the</span> +<span class="sd"> `EdgeConv` operation. Defaults to all features.</span> +<span class="sd"> **kwargs: Additional features to be passed to `EdgeConv`.</span> +<span class="sd"> """</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="n">features_subset</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">features_subset</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># Use all features</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">features_subset</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">slice</span><span class="p">))</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nn</span><span class="o">=</span><span class="n">nn</span><span class="p">,</span> <span class="n">aggr</span><span class="o">=</span><span class="n">aggr</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + + <span class="c1"># Additional member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nb_neighbors</span> <span class="o">=</span> <span class="n">nb_neighbors</span> + <span class="bp">self</span><span class="o">.</span><span class="n">features_subset</span> <span class="o">=</span> <span class="n">features_subset</span> + +<div class="viewcode-block" id="DynEdgeConv.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">:</span> <span class="n">Adj</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Forward pass."""</span> + <span class="c1"># Standard EdgeConv forward pass</span> + <span class="n">x</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">)</span> + + <span class="c1"># Recompute adjacency</span> + <span class="n">edge_index</span> <span class="o">=</span> <span class="n">knn_graph</span><span class="p">(</span> + <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">features_subset</span><span class="p">],</span> + <span class="n">k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_neighbors</span><span class="p">,</span> + <span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> + <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span></div> +</div> + + + +<div class="viewcode-block" id="EdgeConvTito"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito">[docs]</a> +<span class="k">class</span> <span class="nc">EdgeConvTito</span><span class="p">(</span><span class="n">MessagePassing</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Implementation of EdgeConvTito layer used in TITO solution for.</span> + +<span class="sd"> 'IceCube - Neutrinos in Deep' kaggle competition.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">nn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> + <span class="n">aggr</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"max"</span><span class="p">,</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `EdgeConvTito`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> nn: The MLP/torch.Module to be used within the `EdgeConvTito`.</span> +<span class="sd"> aggr: Aggregation method to be used with `EdgeConvTito`.</span> +<span class="sd"> **kwargs: Additional features to be passed to `EdgeConvTito`.</span> +<span class="sd"> """</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">aggr</span><span class="o">=</span><span class="n">aggr</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nn</span> <span class="o">=</span> <span class="n">nn</span> + <span class="bp">self</span><span class="o">.</span><span class="n">reset_parameters</span><span class="p">()</span> + +<div class="viewcode-block" id="EdgeConvTito.reset_parameters"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.reset_parameters">[docs]</a> + <span class="k">def</span> <span class="nf">reset_parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Reset all learnable parameters of the module."""</span> + <span class="n">reset</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nn</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="EdgeConvTito.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">PairTensor</span><span class="p">],</span> <span class="n">edge_index</span><span class="p">:</span> <span class="n">Adj</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Forward pass."""</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span> + <span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> + <span class="c1"># propagate_type: (x: PairTensor)</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">propagate</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="EdgeConvTito.message"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.message">[docs]</a> + <span class="k">def</span> <span class="nf">message</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_i</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">x_j</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Edgeconvtito message passing."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">nn</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x_i</span><span class="p">,</span> <span class="n">x_j</span> <span class="o">-</span> <span class="n">x_i</span><span class="p">,</span> <span class="n">x_j</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> + <span class="p">)</span> <span class="c1"># EdgeConvTito</span></div> + + + <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Print out module name."""</span> + <span class="k">return</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">(nn=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">nn</span><span class="si">}</span><span class="s2">)"</span></div> + + + +<div class="viewcode-block" id="DynTrans"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans">[docs]</a> +<span class="k">class</span> <span class="nc">DynTrans</span><span class="p">(</span><span class="n">EdgeConvTito</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Implementation of dynTrans1 layer used in TITO solution for.</span> + +<span class="sd"> 'IceCube - Neutrinos in Deep' kaggle competition.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">aggr</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"max"</span><span class="p">,</span> + <span class="n">features_subset</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">slice</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">n_head</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `DynTrans`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> nn: The MLP/torch.Module to be used within the `DynTrans`.</span> +<span class="sd"> layer_sizes: List of layer sizes to be used in `DynTrans`.</span> +<span class="sd"> aggr: Aggregation method to be used with `DynTrans`.</span> +<span class="sd"> features_subset: Subset of features in `Data.x` that should be used</span> +<span class="sd"> when dynamically performing the new graph clustering after the</span> +<span class="sd"> `EdgeConv` operation. Defaults to all features.</span> +<span class="sd"> n_head: Number of heads to be used in the multiheadattention models.</span> +<span class="sd"> **kwargs: Additional features to be passed to `DynTrans`.</span> +<span class="sd"> """</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="n">features_subset</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">features_subset</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># Use all features</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">features_subset</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">slice</span><span class="p">))</span> + + <span class="k">if</span> <span class="n">layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">]</span> + <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span> + <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span> + <span class="p">):</span> + <span class="k">if</span> <span class="n">ix</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> + <span class="n">nb_in</span> <span class="o">*=</span> <span class="mi">3</span> <span class="c1"># edgeConv1</span> + <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span> + <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">())</span> + <span class="n">d_model</span> <span class="o">=</span> <span class="n">nb_out</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nn</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">),</span> <span class="n">aggr</span><span class="o">=</span><span class="n">aggr</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + + <span class="c1"># Additional member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">features_subset</span> <span class="o">=</span> <span class="n">features_subset</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span> <span class="c1"># lNorm</span> + + <span class="c1"># Transformer layer(s)</span> + <span class="n">encoder_layer</span> <span class="o">=</span> <span class="n">TransformerEncoderLayer</span><span class="p">(</span> + <span class="n">d_model</span><span class="o">=</span><span class="n">d_model</span><span class="p">,</span> + <span class="n">nhead</span><span class="o">=</span><span class="n">n_head</span><span class="p">,</span> + <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> + <span class="n">norm_first</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_transformer_encoder</span> <span class="o">=</span> <span class="n">TransformerEncoder</span><span class="p">(</span> + <span class="n">encoder_layer</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">1</span> + <span class="p">)</span> + +<div class="viewcode-block" id="DynTrans.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">:</span> <span class="n">Adj</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Forward pass."""</span> + <span class="n">x_out</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">)</span> + + <span class="k">if</span> <span class="n">x_out</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">x_out</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">x_out</span> + + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># lNorm</span> + + <span class="c1"># Transformer layer</span> + <span class="n">x</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">to_dense_batch</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transformer_encoder</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">src_key_padding_mask</span><span class="o">=~</span><span class="n">mask</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> + + <span class="k">return</span> <span class="n">x</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/components/pool.html b/_modules/graphnet/models/components/pool.html new file mode 100644 index 000000000..2523dfb69 --- /dev/null +++ b/_modules/graphnet/models/components/pool.html @@ -0,0 +1,656 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.components.pool — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/components/pool" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.components.pool </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-components-pool--page-root">Source code for graphnet.models.components.pool</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Functions for performing pooling/clustering/coarsening."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn.pool.consecutive</span> <span class="kn">import</span> <span class="n">consecutive_cluster</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn.pool.pool</span> <span class="kn">import</span> <span class="n">pool_edge</span><span class="p">,</span> <span class="n">pool_batch</span><span class="p">,</span> <span class="n">pool_pos</span> +<span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter</span><span class="p">,</span> <span class="n">scatter_std</span> + +<span class="kn">from</span> <span class="nn">torch_geometric.nn.pool</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">avg_pool</span><span class="p">,</span> + <span class="n">max_pool</span><span class="p">,</span> + <span class="n">avg_pool_x</span><span class="p">,</span> + <span class="n">max_pool_x</span><span class="p">,</span> +<span class="p">)</span> + + +<div class="viewcode-block" id="min_pool"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool">[docs]</a> +<span class="k">def</span> <span class="nf">min_pool</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform min-pooling of `Data`.</span> + +<span class="sd"> Like `max_pool, just negating `data.x`.</span> +<span class="sd"> """</span> + <span class="n">data</span><span class="o">.</span><span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="n">data</span><span class="o">.</span><span class="n">x</span> + <span class="n">data_pooled</span> <span class="o">=</span> <span class="n">max_pool</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">,</span> + <span class="n">data</span><span class="p">,</span> + <span class="n">transform</span><span class="p">,</span> + <span class="p">)</span> + <span class="n">data</span><span class="o">.</span><span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="n">data</span><span class="o">.</span><span class="n">x</span> + <span class="n">data_pooled</span><span class="o">.</span><span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="n">data_pooled</span><span class="o">.</span><span class="n">x</span> + <span class="k">return</span> <span class="n">data_pooled</span></div> + + + +<div class="viewcode-block" id="min_pool_x"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool_x">[docs]</a> +<span class="k">def</span> <span class="nf">min_pool_x</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform min-pooling of `Tensor`.</span> + +<span class="sd"> Like `max_pool_x, just negating `x`.</span> +<span class="sd"> """</span> + <span class="n">ret</span> <span class="o">=</span> <span class="n">max_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="o">-</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">size</span><span class="p">)</span> + <span class="k">if</span> <span class="n">size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="o">-</span><span class="n">ret</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ret</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="o">-</span><span class="n">ret</span></div> + + + +<div class="viewcode-block" id="sum_pool_and_distribute"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_and_distribute">[docs]</a> +<span class="k">def</span> <span class="nf">sum_pool_and_distribute</span><span class="p">(</span> + <span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="n">cluster_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Sum-pool values and distribute result to the individual nodes."""</span> + <span class="k">if</span> <span class="n">batch</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">batch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span><span class="o">.</span><span class="n">long</span><span class="p">()</span> + <span class="n">tensor_pooled</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">sum_pool_x</span><span class="p">(</span><span class="n">cluster_index</span><span class="p">,</span> <span class="n">tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="n">inv</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster_index</span><span class="p">)</span> + <span class="n">tensor_unpooled</span> <span class="o">=</span> <span class="n">tensor_pooled</span><span class="p">[</span><span class="n">inv</span><span class="p">]</span> + <span class="k">return</span> <span class="n">tensor_unpooled</span></div> + + + +<span class="k">def</span> <span class="nf">_group_identical</span><span class="p">(</span> + <span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LongTensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">LongTensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Group rows in `tensor` that are identical.</span> + +<span class="sd"> Args:</span> +<span class="sd"> tensor: Tensor of shape [N, F].</span> +<span class="sd"> batch: Batch indices, to only group identical rows within batches.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> List of group indices, from 0 to num. groups - 1, assigning all</span> +<span class="sd"> identical rows to the same group.</span> +<span class="sd"> """</span> + <span class="k">if</span> <span class="n">batch</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">batch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">tensor</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">return_inverse</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="nb">sorted</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> + + +<div class="viewcode-block" id="group_by"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_by">[docs]</a> +<span class="k">def</span> <span class="nf">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Data</span><span class="p">,</span> <span class="n">Batch</span><span class="p">],</span> <span class="n">keys</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">])</span> <span class="o">-></span> <span class="n">LongTensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Group nodes in `data` that have identical values of `keys`.</span> + +<span class="sd"> This grouping is done with in each event in case of batching. This allows</span> +<span class="sd"> for, e.g., assigning the same index to all pulses on the same PMT or DOM in</span> +<span class="sd"> the same event. This can be used for coarsening graphs, e.g., from pulse-</span> +<span class="sd"> level to DOM-level by aggregating feature across each group returned by this</span> +<span class="sd"> method.</span> + +<span class="sd"> Example:</span> +<span class="sd"> Given:</span> +<span class="sd"> data.f1 = [1,1,2,2,2]</span> +<span class="sd"> data.f2 = [6,7,7,7,8]</span> +<span class="sd"> Calls:</span> +<span class="sd"> groupby(data, ['f1']) -> [0, 0, 1, 1, 1]</span> +<span class="sd"> groupby(data, ['f2']) -> [0, 1, 1, 1, 2]</span> +<span class="sd"> groupby(data, ['f1', 'f2']) -> [0, 1, 2, 2, 3]</span> +<span class="sd"> """</span> + <span class="n">features</span> <span class="o">=</span> <span class="p">[</span><span class="nb">getattr</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">keys</span><span class="p">]</span> + <span class="n">tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">features</span><span class="p">)</span><span class="o">.</span><span class="n">T</span> <span class="c1"># .int() @TODO: Required? Use rounding?</span> + <span class="n">batch</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="s2">"batch"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> + <span class="n">index</span> <span class="o">=</span> <span class="n">_group_identical</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="k">return</span> <span class="n">index</span></div> + + + +<div class="viewcode-block" id="group_pulses_to_dom"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_dom">[docs]</a> +<span class="k">def</span> <span class="nf">group_pulses_to_dom</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Group pulses on the same DOM, using DOM and string number."""</span> + <span class="n">data</span><span class="o">.</span><span class="n">dom_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">[</span><span class="s2">"dom_number"</span><span class="p">,</span> <span class="s2">"string"</span><span class="p">])</span> + <span class="k">return</span> <span class="n">data</span></div> + + + +<div class="viewcode-block" id="group_pulses_to_pmt"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_pmt">[docs]</a> +<span class="k">def</span> <span class="nf">group_pulses_to_pmt</span><span class="p">(</span><span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Group pulses on the same PMT, using PMT, DOM, and string number."""</span> + <span class="n">data</span><span class="o">.</span><span class="n">pmt_index</span> <span class="o">=</span> <span class="n">group_by</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">[</span><span class="s2">"pmt_number"</span><span class="p">,</span> <span class="s2">"dom_number"</span><span class="p">,</span> <span class="s2">"string"</span><span class="p">])</span> + <span class="k">return</span> <span class="n">data</span></div> + + + +<span class="c1"># Below mirroring `torch_geometric.nn.pool.{avg,max}_pool.py`.</span> +<span class="k">def</span> <span class="nf">_sum_pool_x</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">scatter</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">cluster</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim_size</span><span class="o">=</span><span class="n">size</span><span class="p">,</span> <span class="n">reduce</span><span class="o">=</span><span class="s2">"sum"</span><span class="p">)</span> + + +<span class="k">def</span> <span class="nf">_std_pool_x</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">scatter_std</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">cluster</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim_size</span><span class="o">=</span><span class="n">size</span><span class="p">,</span> <span class="n">unbiased</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> + + +<div class="viewcode-block" id="sum_pool_x"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_x">[docs]</a> +<span class="k">def</span> <span class="nf">sum_pool_x</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sa">r</span><span class="sd">"""Sum-pool node features according to the clustering defined in `cluster`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,</span> +<span class="sd"> \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.</span> +<span class="sd"> x: Node feature matrix</span> +<span class="sd"> :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.</span> +<span class="sd"> batch: Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,</span> +<span class="sd"> B-1\}}^N`, which assigns each node to a specific example.</span> +<span class="sd"> size: The maximum number of clusters in a single</span> +<span class="sd"> example. This property is useful to obtain a batch-wise dense</span> +<span class="sd"> representation, *e.g.* for applying FC layers, but should only be</span> +<span class="sd"> used if the size of the maximum number of clusters per example is</span> +<span class="sd"> known in advance.</span> +<span class="sd"> """</span> + <span class="k">if</span> <span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="o">+</span> <span class="mi">1</span> + <span class="k">return</span> <span class="n">_sum_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">size</span><span class="p">),</span> <span class="kc">None</span> + + <span class="n">cluster</span><span class="p">,</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">_sum_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> + <span class="n">batch</span> <span class="o">=</span> <span class="n">pool_batch</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">batch</span></div> + + + +<div class="viewcode-block" id="std_pool_x"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool_x">[docs]</a> +<span class="k">def</span> <span class="nf">std_pool_x</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sa">r</span><span class="sd">"""Std-pool node features according to the clustering defined in `cluster`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,</span> +<span class="sd"> \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.</span> +<span class="sd"> x: Node feature matrix</span> +<span class="sd"> :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.</span> +<span class="sd"> batch: Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,</span> +<span class="sd"> B-1\}}^N`, which assigns each node to a specific example.</span> +<span class="sd"> size: The maximum number of clusters in a single</span> +<span class="sd"> example. This property is useful to obtain a batch-wise dense</span> +<span class="sd"> representation, *e.g.* for applying FC layers, but should only be</span> +<span class="sd"> used if the size of the maximum number of clusters per example is</span> +<span class="sd"> known in advance.</span> +<span class="sd"> """</span> + <span class="k">if</span> <span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="o">+</span> <span class="mi">1</span> + <span class="k">return</span> <span class="n">_std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">size</span><span class="p">),</span> <span class="kc">None</span> + + <span class="n">cluster</span><span class="p">,</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">_std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> + <span class="n">batch</span> <span class="o">=</span> <span class="n">pool_batch</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">batch</span></div> + + + +<div class="viewcode-block" id="sum_pool"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool">[docs]</a> +<span class="k">def</span> <span class="nf">sum_pool</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sa">r</span><span class="sd">"""Pool and coarsen graph according to the clustering defined in `cluster`.</span> + +<span class="sd"> All nodes within the same cluster will be represented as one node.</span> +<span class="sd"> Final node features are defined by the *sum* of features of all nodes</span> +<span class="sd"> within the same cluster, node positions are averaged and edge indices are</span> +<span class="sd"> defined to be the union of the edge indices of all nodes within the same</span> +<span class="sd"> cluster.</span> + +<span class="sd"> Args:</span> +<span class="sd"> cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,</span> +<span class="sd"> \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.</span> +<span class="sd"> data: Graph data object.</span> +<span class="sd"> transform: A function/transform that takes in the</span> +<span class="sd"> coarsened and pooled :obj:`torch_geometric.data.Data` object and</span> +<span class="sd"> returns a transformed version.</span> +<span class="sd"> """</span> + <span class="n">cluster</span><span class="p">,</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster</span><span class="p">)</span> + + <span class="n">x</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">_sum_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">)</span> + <span class="n">index</span><span class="p">,</span> <span class="n">attr</span> <span class="o">=</span> <span class="n">pool_edge</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_attr</span><span class="p">)</span> + <span class="n">batch</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">pool_batch</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span><span class="p">)</span> + <span class="n">pos</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">pos</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">pool_pos</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">pos</span><span class="p">)</span> + + <span class="n">data</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="o">=</span><span class="n">index</span><span class="p">,</span> <span class="n">edge_attr</span><span class="o">=</span><span class="n">attr</span><span class="p">,</span> <span class="n">pos</span><span class="o">=</span><span class="n">pos</span><span class="p">)</span> + + <span class="k">if</span> <span class="n">transform</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">data</span> <span class="o">=</span> <span class="n">transform</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">data</span></div> + + + +<div class="viewcode-block" id="std_pool"> +<a class="viewcode-back" href="../../../../api/graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool">[docs]</a> +<span class="k">def</span> <span class="nf">std_pool</span><span class="p">(</span> + <span class="n">cluster</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sa">r</span><span class="sd">"""Pool and coarsen graph according to the clustering defined in `cluster`.</span> + +<span class="sd"> All nodes within the same cluster will be represented as one node.</span> +<span class="sd"> Final node features are defined by the *std* of features of all nodes</span> +<span class="sd"> within the same cluster, node positions are averaged and edge indices are</span> +<span class="sd"> defined to be the union of the edge indices of all nodes within the same</span> +<span class="sd"> cluster.</span> + +<span class="sd"> Args:</span> +<span class="sd"> cluster: Cluster vector :math:`\mathbf{c} \in \{ 0,</span> +<span class="sd"> \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.</span> +<span class="sd"> data: Graph data object.</span> +<span class="sd"> transform: A function/transform that takes in the</span> +<span class="sd"> coarsened and pooled :obj:`torch_geometric.data.Data` object and</span> +<span class="sd"> returns a transformed version.</span> +<span class="sd"> """</span> + <span class="n">cluster</span><span class="p">,</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">consecutive_cluster</span><span class="p">(</span><span class="n">cluster</span><span class="p">)</span> + + <span class="n">x</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">_std_pool_x</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">)</span> + <span class="n">index</span><span class="p">,</span> <span class="n">attr</span> <span class="o">=</span> <span class="n">pool_edge</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_attr</span><span class="p">)</span> + <span class="n">batch</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">pool_batch</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span><span class="p">)</span> + <span class="n">pos</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">pos</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">pool_pos</span><span class="p">(</span><span class="n">cluster</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">pos</span><span class="p">)</span> + + <span class="n">data</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="o">=</span><span class="n">index</span><span class="p">,</span> <span class="n">edge_attr</span><span class="o">=</span><span class="n">attr</span><span class="p">,</span> <span class="n">pos</span><span class="o">=</span><span class="n">pos</span><span class="p">)</span> + + <span class="k">if</span> <span class="n">transform</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">data</span> <span class="o">=</span> <span class="n">transform</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">data</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/detector/detector.html b/_modules/graphnet/models/detector/detector.html new file mode 100644 index 000000000..eef62aefc --- /dev/null +++ b/_modules/graphnet/models/detector/detector.html @@ -0,0 +1,421 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.detector.detector — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/detector/detector" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.detector.detector </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-detector-detector--page-root">Source code for graphnet.models.detector.detector</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Base detector-specific `Model` class(es)."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">List</span> + +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> +<span class="kn">import</span> <span class="nn">torch</span> + +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> + + +<div class="viewcode-block" id="Detector"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector">[docs]</a> +<span class="k">class</span> <span class="nc">Detector</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base class for all detector-specific read-ins in graphnet."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `Detector`."""</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + +<div class="viewcode-block" id="Detector.feature_map"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector.feature_map">[docs]</a> + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""List of features used/assumed by inheriting `Detector` objects."""</span></div> + + +<div class="viewcode-block" id="Detector.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector.forward">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span> <span class="c1"># type: ignore</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">node_features</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Pre-process graph `Data` features and build graph adjacency."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_standardize</span><span class="p">(</span><span class="n">node_features</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">)</span></div> + + + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">_standardize</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">node_features</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> + <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">feature</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">):</span> + <span class="k">try</span><span class="p">:</span> + <span class="n">node_features</span><span class="p">[:,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature_map</span><span class="p">()[</span><span class="n">feature</span><span class="p">](</span> <span class="c1"># type: ignore</span> + <span class="n">node_features</span><span class="p">[:,</span> <span class="n">idx</span><span class="p">]</span> + <span class="p">)</span> + <span class="k">except</span> <span class="ne">KeyError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"""No Standardization function found for '</span><span class="si">{</span><span class="n">feature</span><span class="si">}</span><span class="s2">'"""</span> + <span class="p">)</span> + <span class="k">raise</span> <span class="n">e</span> + <span class="k">return</span> <span class="n">node_features</span> + + <span class="k">def</span> <span class="nf">_identity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Apply no standardization to input."""</span> + <span class="k">return</span> <span class="n">x</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/detector/icecube.html b/_modules/graphnet/models/detector/icecube.html new file mode 100644 index 000000000..76d9a8e16 --- /dev/null +++ b/_modules/graphnet/models/detector/icecube.html @@ -0,0 +1,528 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.detector.icecube — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/detector/icecube" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.detector.icecube </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-detector-icecube--page-root">Source code for graphnet.models.detector.icecube</h1><div class="highlight"><pre> +<span></span><span class="sd">"""IceCube-specific `Detector` class(es)."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Callable</span> +<span class="kn">import</span> <span class="nn">torch</span> + +<span class="kn">from</span> <span class="nn">graphnet.models.detector.detector</span> <span class="kn">import</span> <span class="n">Detector</span> + + +<div class="viewcode-block" id="IceCube86"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86">[docs]</a> +<span class="k">class</span> <span class="nc">IceCube86</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""`Detector` class for IceCube-86."""</span> + +<div class="viewcode-block" id="IceCube86.feature_map"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86.feature_map">[docs]</a> + <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Map standardization functions to each dimension of input data."""</span> + <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"dom_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span> + <span class="s2">"dom_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span> + <span class="s2">"dom_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span> + <span class="s2">"dom_time"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_time</span><span class="p">,</span> + <span class="s2">"charge"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_charge</span><span class="p">,</span> + <span class="s2">"rde"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_rde</span><span class="p">,</span> + <span class="s2">"pmt_area"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pmt_area</span><span class="p">,</span> + <span class="p">}</span> + <span class="k">return</span> <span class="n">feature_map</span></div> + + + <span class="k">def</span> <span class="nf">_dom_xyz</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">500.0</span> + + <span class="k">def</span> <span class="nf">_dom_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">1.0e04</span><span class="p">)</span> <span class="o">/</span> <span class="mf">3.0e4</span> + + <span class="k">def</span> <span class="nf">_charge</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_rde</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">1.25</span><span class="p">)</span> <span class="o">/</span> <span class="mf">0.25</span> + + <span class="k">def</span> <span class="nf">_pmt_area</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">0.05</span></div> + + + +<div class="viewcode-block" id="IceCubeKaggle"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle">[docs]</a> +<span class="k">class</span> <span class="nc">IceCubeKaggle</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""`Detector` class for Kaggle Competition."""</span> + +<div class="viewcode-block" id="IceCubeKaggle.feature_map"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle.feature_map">[docs]</a> + <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Map standardization functions to each dimension of input data."""</span> + <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_xyz</span><span class="p">,</span> + <span class="s2">"y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_xyz</span><span class="p">,</span> + <span class="s2">"z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_xyz</span><span class="p">,</span> + <span class="s2">"time"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_time</span><span class="p">,</span> + <span class="s2">"charge"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_charge</span><span class="p">,</span> + <span class="s2">"auxiliary"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span> + <span class="p">}</span> + <span class="k">return</span> <span class="n">feature_map</span></div> + + + <span class="k">def</span> <span class="nf">_xyz</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">500.0</span> + + <span class="k">def</span> <span class="nf">_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">1.0e04</span><span class="p">)</span> <span class="o">/</span> <span class="mf">3.0e4</span> + + <span class="k">def</span> <span class="nf">_charge</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="mf">3.0</span></div> + + + +<div class="viewcode-block" id="IceCubeDeepCore"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore">[docs]</a> +<span class="k">class</span> <span class="nc">IceCubeDeepCore</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""`Detector` class for IceCube-DeepCore."""</span> + +<div class="viewcode-block" id="IceCubeDeepCore.feature_map"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map">[docs]</a> + <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Map standardization functions to each dimension of input data."""</span> + <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"dom_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xy</span><span class="p">,</span> + <span class="s2">"dom_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xy</span><span class="p">,</span> + <span class="s2">"dom_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_z</span><span class="p">,</span> + <span class="s2">"dom_time"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_time</span><span class="p">,</span> + <span class="s2">"charge"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span> + <span class="s2">"rde"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_rde</span><span class="p">,</span> + <span class="s2">"pmt_area"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pmt_area</span><span class="p">,</span> + <span class="p">}</span> + <span class="k">return</span> <span class="n">feature_map</span></div> + + + <span class="k">def</span> <span class="nf">_dom_xy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">100.0</span> + + <span class="k">def</span> <span class="nf">_dom_z</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mf">350.0</span><span class="p">)</span> <span class="o">/</span> <span class="mf">100.0</span> + + <span class="k">def</span> <span class="nf">_dom_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">((</span><span class="n">x</span> <span class="o">/</span> <span class="mf">1.05e04</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="mf">20.0</span> + + <span class="k">def</span> <span class="nf">_rde</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">1.25</span><span class="p">)</span> <span class="o">/</span> <span class="mf">0.25</span> + + <span class="k">def</span> <span class="nf">_pmt_area</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">0.05</span></div> + + + +<div class="viewcode-block" id="IceCubeUpgrade"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade">[docs]</a> +<span class="k">class</span> <span class="nc">IceCubeUpgrade</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""`Detector` class for IceCube-Upgrade."""</span> + +<div class="viewcode-block" id="IceCubeUpgrade.feature_map"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map">[docs]</a> + <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Map standardization functions to each dimension of input data."""</span> + <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"dom_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span> + <span class="s2">"dom_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span> + <span class="s2">"dom_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_xyz</span><span class="p">,</span> + <span class="s2">"dom_time"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_time</span><span class="p">,</span> + <span class="s2">"charge"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_charge</span><span class="p">,</span> + <span class="s2">"rde"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span> + <span class="s2">"pmt_area"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pmt_area</span><span class="p">,</span> + <span class="s2">"string"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_string</span><span class="p">,</span> + <span class="s2">"pmt_number"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pmt_number</span><span class="p">,</span> + <span class="s2">"dom_number"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_number</span><span class="p">,</span> + <span class="s2">"pmt_dir_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span> + <span class="s2">"pmt_dir_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span> + <span class="s2">"pmt_dir_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_identity</span><span class="p">,</span> + <span class="s2">"dom_type"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dom_type</span><span class="p">,</span> + <span class="p">}</span> + + <span class="k">return</span> <span class="n">feature_map</span></div> + + + <span class="k">def</span> <span class="nf">_dom_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="mf">2e04</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.0</span> + + <span class="k">def</span> <span class="nf">_charge</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="mf">2.0</span> + + <span class="k">def</span> <span class="nf">_string</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">50.0</span><span class="p">)</span> <span class="o">/</span> <span class="mf">50.0</span> + + <span class="k">def</span> <span class="nf">_pmt_number</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">20.0</span> + + <span class="k">def</span> <span class="nf">_dom_number</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="mf">60.0</span><span class="p">)</span> <span class="o">/</span> <span class="mf">60.0</span> + + <span class="k">def</span> <span class="nf">_dom_type</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">130.0</span> + + <span class="k">def</span> <span class="nf">_dom_xyz</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">500.0</span> + + <span class="k">def</span> <span class="nf">_pmt_area</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">0.05</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/detector/prometheus.html b/_modules/graphnet/models/detector/prometheus.html new file mode 100644 index 000000000..7f234af38 --- /dev/null +++ b/_modules/graphnet/models/detector/prometheus.html @@ -0,0 +1,395 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.detector.prometheus — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/detector/prometheus" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.detector.prometheus </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-detector-prometheus--page-root">Source code for graphnet.models.detector.prometheus</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Prometheus-specific `Detector` class(es)."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Callable</span> +<span class="kn">import</span> <span class="nn">torch</span> + +<span class="kn">from</span> <span class="nn">graphnet.models.detector.detector</span> <span class="kn">import</span> <span class="n">Detector</span> + + +<div class="viewcode-block" id="Prometheus"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus">[docs]</a> +<span class="k">class</span> <span class="nc">Prometheus</span><span class="p">(</span><span class="n">Detector</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""`Detector` class for Prometheus prototype."""</span> + +<div class="viewcode-block" id="Prometheus.feature_map"> +<a class="viewcode-back" href="../../../../api/graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus.feature_map">[docs]</a> + <span class="k">def</span> <span class="nf">feature_map</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Map standardization functions to each dimension."""</span> + <span class="n">feature_map</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"sensor_pos_x"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sensor_pos_xy</span><span class="p">,</span> + <span class="s2">"sensor_pos_y"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sensor_pos_xy</span><span class="p">,</span> + <span class="s2">"sensor_pos_z"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sensor_pos_z</span><span class="p">,</span> + <span class="s2">"t"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_t</span><span class="p">,</span> + <span class="p">}</span> + <span class="k">return</span> <span class="n">feature_map</span></div> + + + <span class="k">def</span> <span class="nf">_sensor_pos_xy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">/</span> <span class="mi">100</span> + + <span class="k">def</span> <span class="nf">_sensor_pos_z</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mi">350</span><span class="p">)</span> <span class="o">/</span> <span class="mi">100</span> + + <span class="k">def</span> <span class="nf">_t</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="p">((</span><span class="n">x</span> <span class="o">/</span> <span class="mf">1.05e04</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="mf">20.0</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/convnet.html b/_modules/graphnet/models/gnn/convnet.html new file mode 100644 index 000000000..ada586304 --- /dev/null +++ b/_modules/graphnet/models/gnn/convnet.html @@ -0,0 +1,486 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.gnn.convnet — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/gnn/convnet" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.gnn.convnet </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-gnn-convnet--page-root">Source code for graphnet.models.gnn.convnet</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Implementation of the ConvNet GNN model architecture.</span> + +<span class="sd">Author: Martin Ha Minh</span> +<span class="sd">"""</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">BatchNorm1d</span><span class="p">,</span> <span class="n">Linear</span><span class="p">,</span> <span class="n">Dropout</span> +<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn</span> <span class="kn">import</span> <span class="n">TAGConv</span><span class="p">,</span> <span class="n">global_add_pool</span><span class="p">,</span> <span class="n">global_max_pool</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span> + + +<div class="viewcode-block" id="ConvNet"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet">[docs]</a> +<span class="k">class</span> <span class="nc">ConvNet</span><span class="p">(</span><span class="n">GNN</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""ConvNet (convolutional network) model."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">nb_outputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">nb_intermediate</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span> + <span class="n">dropout_ratio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.3</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `ConvNet`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> nb_inputs: Number of input features, i.e. dimension of input</span> +<span class="sd"> layer.</span> +<span class="sd"> nb_outputs: Number of prediction labels, i.e. dimension of</span> +<span class="sd"> output layer.</span> +<span class="sd"> nb_intermediate: Number of nodes in intermediate layer(s).</span> +<span class="sd"> dropout_ratio: Fraction of nodes to drop.</span> +<span class="sd"> """</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="n">nb_outputs</span><span class="p">)</span> + + <span class="c1"># Member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span> <span class="o">=</span> <span class="n">nb_intermediate</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span> <span class="o">=</span> <span class="mi">6</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span> + + <span class="c1"># Architecture configuration</span> + <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">TAGConv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">TAGConv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span> <span class="o">=</span> <span class="n">TAGConv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">batchnorm1</span> <span class="o">=</span> <span class="n">BatchNorm1d</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">linear1</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">linear2</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">linear3</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">linear4</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">linear5</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">drop1</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">drop2</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">drop3</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">drop4</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">drop5</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_ratio</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">out</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nb_intermediate2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_outputs</span><span class="p">)</span> + +<div class="viewcode-block" id="ConvNet.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Apply learnable forward pass."""</span> + <span class="c1"># Convenience variables</span> + <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> + + <span class="c1"># Graph convolutional operations</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">))</span> + <span class="n">x1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">global_add_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span> + <span class="n">global_max_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span> + <span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">))</span> + <span class="n">x2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">global_add_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span> + <span class="n">global_max_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span> + <span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv3</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">))</span> + <span class="n">x3</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">global_add_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span> + <span class="n">global_max_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">),</span> + <span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Skip-cat</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">x3</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + + <span class="c1"># Batch-normalising intermediate features</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batchnorm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="c1"># Post-processing</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear3</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear4</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop4</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">linear5</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop5</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="c1"># Read-out</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">out</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">x</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/dynedge.html b/_modules/graphnet/models/gnn/dynedge.html new file mode 100644 index 000000000..d882546a3 --- /dev/null +++ b/_modules/graphnet/models/gnn/dynedge.html @@ -0,0 +1,693 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.gnn.dynedge — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/gnn/dynedge" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.gnn.dynedge </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-gnn-dynedge--page-root">Source code for graphnet.models.gnn.dynedge</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Implementation of the DynEdge GNN model architecture."""</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> +<span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter_max</span><span class="p">,</span> <span class="n">scatter_mean</span><span class="p">,</span> <span class="n">scatter_min</span><span class="p">,</span> <span class="n">scatter_sum</span> + +<span class="kn">from</span> <span class="nn">graphnet.models.components.layers</span> <span class="kn">import</span> <span class="n">DynEdgeConv</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span> +<span class="kn">from</span> <span class="nn">graphnet.models.utils</span> <span class="kn">import</span> <span class="n">calculate_xyzt_homophily</span> + +<span class="n">GLOBAL_POOLINGS</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"min"</span><span class="p">:</span> <span class="n">scatter_min</span><span class="p">,</span> + <span class="s2">"max"</span><span class="p">:</span> <span class="n">scatter_max</span><span class="p">,</span> + <span class="s2">"sum"</span><span class="p">:</span> <span class="n">scatter_sum</span><span class="p">,</span> + <span class="s2">"mean"</span><span class="p">:</span> <span class="n">scatter_mean</span><span class="p">,</span> +<span class="p">}</span> + + +<div class="viewcode-block" id="DynEdge"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge">[docs]</a> +<span class="k">class</span> <span class="nc">DynEdge</span><span class="p">(</span><span class="n">GNN</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""DynEdge (dynamical edge convolutional) model."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">nb_neighbours</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> + <span class="n">features_subset</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">slice</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">dynedge_layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">post_processing_layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">readout_layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">global_pooling_schemes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">add_global_variables_after_pooling</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `DynEdge`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> nb_inputs: Number of input features on each node.</span> +<span class="sd"> nb_neighbours: Number of neighbours to used in the k-nearest</span> +<span class="sd"> neighbour clustering which is performed after each (dynamical)</span> +<span class="sd"> edge convolution.</span> +<span class="sd"> features_subset: The subset of latent features on each node that</span> +<span class="sd"> are used as metric dimensions when performing the k-nearest</span> +<span class="sd"> neighbours clustering. Defaults to [0,1,2].</span> +<span class="sd"> dynedge_layer_sizes: The layer sizes, or latent feature dimenions,</span> +<span class="sd"> used in the `DynEdgeConv` layer. Each entry in</span> +<span class="sd"> `dynedge_layer_sizes` corresponds to a single `DynEdgeConv`</span> +<span class="sd"> layer; the integers in the corresponding tuple corresponds to</span> +<span class="sd"> the layer sizes in the multi-layer perceptron (MLP) that is</span> +<span class="sd"> applied within each `DynEdgeConv` layer. That is, a list of</span> +<span class="sd"> size-two tuples means that all `DynEdgeConv` layers contain a</span> +<span class="sd"> two-layer MLP.</span> +<span class="sd"> Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)].</span> +<span class="sd"> post_processing_layer_sizes: Hidden layer sizes in the MLP</span> +<span class="sd"> following the skip-concatenation of the outputs of each</span> +<span class="sd"> `DynEdgeConv` layer. Defaults to [336, 256].</span> +<span class="sd"> readout_layer_sizes: Hidden layer sizes in the MLP following the</span> +<span class="sd"> post-processing _and_ optional global pooling. As this is the</span> +<span class="sd"> last layer(s) in the model, the last layer in the read-out</span> +<span class="sd"> yields the output of the `DynEdge` model. Defaults to [128,].</span> +<span class="sd"> global_pooling_schemes: The list global pooling schemes to use.</span> +<span class="sd"> Options are: "min", "max", "mean", and "sum".</span> +<span class="sd"> add_global_variables_after_pooling: Whether to add global variables</span> +<span class="sd"> after global pooling. The alternative is to added (distribute)</span> +<span class="sd"> them to the individual nodes before any convolutional</span> +<span class="sd"> operations.</span> +<span class="sd"> """</span> + <span class="c1"># Latent feature subset for computing nearest neighbours in DynEdge.</span> + <span class="k">if</span> <span class="n">features_subset</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">features_subset</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> + + <span class="c1"># DynEdge layer sizes</span> + <span class="k">if</span> <span class="n">dynedge_layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">dynedge_layer_sizes</span> <span class="o">=</span> <span class="p">[</span> + <span class="p">(</span> + <span class="mi">128</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">),</span> + <span class="p">(</span> + <span class="mi">336</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">),</span> + <span class="p">(</span> + <span class="mi">336</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">),</span> + <span class="p">(</span> + <span class="mi">336</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">),</span> + <span class="p">]</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dynedge_layer_sizes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dynedge_layer_sizes</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dynedge_layer_sizes</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dynedge_layer_sizes</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span> + <span class="nb">all</span><span class="p">(</span><span class="n">size</span> <span class="o">></span> <span class="mi">0</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">sizes</span><span class="p">)</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dynedge_layer_sizes</span> + <span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_dynedge_layer_sizes</span> <span class="o">=</span> <span class="n">dynedge_layer_sizes</span> + + <span class="c1"># Post-processing layer sizes</span> + <span class="k">if</span> <span class="n">post_processing_layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">post_processing_layer_sizes</span> <span class="o">=</span> <span class="p">[</span> + <span class="mi">336</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">]</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">post_processing_layer_sizes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">post_processing_layer_sizes</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="n">size</span> <span class="o">></span> <span class="mi">0</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">post_processing_layer_sizes</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing_layer_sizes</span> <span class="o">=</span> <span class="n">post_processing_layer_sizes</span> + + <span class="c1"># Read-out layer sizes</span> + <span class="k">if</span> <span class="n">readout_layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">readout_layer_sizes</span> <span class="o">=</span> <span class="p">[</span> + <span class="mi">128</span><span class="p">,</span> + <span class="p">]</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">readout_layer_sizes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">readout_layer_sizes</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="n">size</span> <span class="o">></span> <span class="mi">0</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">readout_layer_sizes</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span> <span class="o">=</span> <span class="n">readout_layer_sizes</span> + + <span class="c1"># Global pooling scheme(s)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">global_pooling_schemes</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">global_pooling_schemes</span> <span class="o">=</span> <span class="p">[</span><span class="n">global_pooling_schemes</span><span class="p">]</span> + + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">global_pooling_schemes</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="k">for</span> <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="n">global_pooling_schemes</span><span class="p">:</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="n">GLOBAL_POOLINGS</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"Global pooling scheme </span><span class="si">{</span><span class="n">pooling_scheme</span><span class="si">}</span><span class="s2"> not supported."</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">assert</span> <span class="n">global_pooling_schemes</span> <span class="ow">is</span> <span class="kc">None</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span> <span class="o">=</span> <span class="n">global_pooling_schemes</span> + + <span class="k">if</span> <span class="n">add_global_variables_after_pooling</span><span class="p">:</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">,</span> <span class="p">(</span> + <span class="s2">"No global pooling schemes were request, so cannot add global"</span> + <span class="s2">" variables after pooling."</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">add_global_variables_after_pooling</span> + <span class="p">)</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> + + <span class="c1"># Remaining member variables()</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_activation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">()</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> <span class="o">=</span> <span class="n">nb_inputs</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span> <span class="o">=</span> <span class="mi">5</span> <span class="o">+</span> <span class="n">nb_inputs</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_neighbours</span> <span class="o">=</span> <span class="n">nb_neighbours</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_features_subset</span> <span class="o">=</span> <span class="n">features_subset</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_construct_layers</span><span class="p">()</span> + + <span class="k">def</span> <span class="nf">_construct_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct layers (torch.nn.Modules)."""</span> + <span class="c1"># Convolutional operations</span> + <span class="n">nb_input_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> + <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span><span class="p">:</span> + <span class="n">nb_input_features</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">()</span> + <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">nb_input_features</span> + <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dynedge_layer_sizes</span><span class="p">:</span> + <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span> + <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">sizes</span><span class="p">)</span> + <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span> + <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span> + <span class="p">):</span> + <span class="k">if</span> <span class="n">ix</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> + <span class="n">nb_in</span> <span class="o">*=</span> <span class="mi">2</span> + <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span> + <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span> + + <span class="n">conv_layer</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">),</span> + <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span> + <span class="n">nb_neighbors</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_nb_neighbours</span><span class="p">,</span> + <span class="n">features_subset</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features_subset</span><span class="p">,</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">conv_layer</span><span class="p">)</span> + + <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">nb_out</span> + + <span class="c1"># Post-processing operations</span> + <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="p">(</span> + <span class="nb">sum</span><span class="p">(</span><span class="n">sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dynedge_layer_sizes</span><span class="p">)</span> + <span class="o">+</span> <span class="n">nb_input_features</span> + <span class="p">)</span> + + <span class="n">post_processing_layers</span> <span class="o">=</span> <span class="p">[]</span> + <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing_layer_sizes</span> + <span class="p">)</span> + <span class="k">for</span> <span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span> + <span class="n">post_processing_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span> + <span class="n">post_processing_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">post_processing_layers</span><span class="p">)</span> + + <span class="c1"># Read-out operations</span> + <span class="n">nb_poolings</span> <span class="o">=</span> <span class="p">(</span> + <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span> + <span class="k">else</span> <span class="mi">1</span> + <span class="p">)</span> + <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">nb_out</span> <span class="o">*</span> <span class="n">nb_poolings</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span><span class="p">:</span> + <span class="n">nb_latent_features</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span> + + <span class="n">readout_layers</span> <span class="o">=</span> <span class="p">[]</span> + <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span><span class="p">)</span> + <span class="k">for</span> <span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span> + <span class="n">readout_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span> + <span class="n">readout_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_readout</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">readout_layers</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_global_pooling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform global pooling."""</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span> + <span class="n">pooled</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">for</span> <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">:</span> + <span class="n">pooling_fn</span> <span class="o">=</span> <span class="n">GLOBAL_POOLINGS</span><span class="p">[</span><span class="n">pooling_scheme</span><span class="p">]</span> + <span class="n">pooled_x</span> <span class="o">=</span> <span class="n">pooling_fn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span> + <span class="c1"># `scatter_{min,max}`, which return also an argument, vs.</span> + <span class="c1"># `scatter_{mean,sum}`</span> + <span class="n">pooled_x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">pooled_x</span> + <span class="n">pooled</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">pooled</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_calculate_global_variables</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="n">edge_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="o">*</span><span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Calculate global variables."""</span> + <span class="c1"># Calculate homophily (scalar variables)</span> + <span class="n">h_x</span><span class="p">,</span> <span class="n">h_y</span><span class="p">,</span> <span class="n">h_z</span><span class="p">,</span> <span class="n">h_t</span> <span class="o">=</span> <span class="n">calculate_xyzt_homophily</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + + <span class="c1"># Calculate mean features</span> + <span class="n">global_means</span> <span class="o">=</span> <span class="n">scatter_mean</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + + <span class="c1"># Add global variables</span> + <span class="n">global_variables</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">global_means</span><span class="p">,</span> + <span class="n">h_x</span><span class="p">,</span> + <span class="n">h_y</span><span class="p">,</span> + <span class="n">h_z</span><span class="p">,</span> + <span class="n">h_t</span><span class="p">,</span> + <span class="p">]</span> + <span class="o">+</span> <span class="p">[</span><span class="n">attr</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">additional_attributes</span><span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">global_variables</span> + +<div class="viewcode-block" id="DynEdge.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Apply learnable forward pass."""</span> + <span class="c1"># Convenience variables</span> + <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> + + <span class="n">global_variables</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_calculate_global_variables</span><span class="p">(</span> + <span class="n">x</span><span class="p">,</span> + <span class="n">edge_index</span><span class="p">,</span> + <span class="n">batch</span><span class="p">,</span> + <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">n_pulses</span><span class="p">),</span> + <span class="p">)</span> + + <span class="c1"># Distribute global variables out to each node</span> + <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span><span class="p">:</span> + <span class="n">distribute</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">batch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span> + + <span class="n">global_variables_distributed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span> + <span class="n">distribute</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> + <span class="o">*</span> <span class="n">global_variables</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">global_variables_distributed</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + + <span class="c1"># DynEdge-convolutions</span> + <span class="n">skip_connections</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span> + <span class="k">for</span> <span class="n">conv_layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span><span class="p">:</span> + <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="n">conv_layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="n">skip_connections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="c1"># Skip-cat</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">skip_connections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + + <span class="c1"># Post-processing</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="c1"># (Optional) Global pooling</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">:</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_global_variables_after_pooling</span><span class="p">:</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">x</span><span class="p">,</span> + <span class="n">global_variables</span><span class="p">,</span> + <span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Read-out</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_readout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">x</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/dynedge_jinst.html b/_modules/graphnet/models/gnn/dynedge_jinst.html new file mode 100644 index 000000000..125dc1399 --- /dev/null +++ b/_modules/graphnet/models/gnn/dynedge_jinst.html @@ -0,0 +1,521 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.gnn.dynedge_jinst — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/gnn/dynedge_jinst" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.gnn.dynedge_jinst </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-gnn-dynedge-jinst--page-root">Source code for graphnet.models.gnn.dynedge_jinst</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Implementation of the exact DynEdge architecture used in [2209.03042].</span> + +<span class="sd">Author: Rasmus Oersoe</span> +<span class="sd">"""</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> +<span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter_max</span><span class="p">,</span> <span class="n">scatter_mean</span><span class="p">,</span> <span class="n">scatter_min</span><span class="p">,</span> <span class="n">scatter_sum</span> + +<span class="kn">from</span> <span class="nn">graphnet.models.components.layers</span> <span class="kn">import</span> <span class="n">DynEdgeConv</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span> +<span class="kn">from</span> <span class="nn">graphnet.models.utils</span> <span class="kn">import</span> <span class="n">calculate_xyzt_homophily</span> + + +<div class="viewcode-block" id="DynEdgeJINST"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST">[docs]</a> +<span class="k">class</span> <span class="nc">DynEdgeJINST</span><span class="p">(</span><span class="n">GNN</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""DynEdge (dynamical edge convolutional) model used in [2209.03042]."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">layer_size_scale</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `DynEdgeJINST`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> nb_inputs: Number of input features.</span> +<span class="sd"> nb_outputs: Number of output features.</span> +<span class="sd"> layer_size_scale: Integer that scales the size of hidden layers.</span> +<span class="sd"> """</span> + <span class="c1"># Architecture configuration</span> + <span class="n">c</span> <span class="o">=</span> <span class="n">layer_size_scale</span> + <span class="n">l1</span><span class="p">,</span> <span class="n">l2</span><span class="p">,</span> <span class="n">l3</span><span class="p">,</span> <span class="n">l4</span><span class="p">,</span> <span class="n">l5</span><span class="p">,</span> <span class="n">l6</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">nb_inputs</span><span class="p">,</span> + <span class="n">c</span> <span class="o">*</span> <span class="mi">16</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> + <span class="n">c</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> + <span class="n">c</span> <span class="o">*</span> <span class="mi">42</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> + <span class="n">c</span> <span class="o">*</span> <span class="mi">32</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> + <span class="n">c</span> <span class="o">*</span> <span class="mi">16</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="n">l6</span><span class="p">)</span> + + <span class="c1"># Graph convolutional operations</span> + <span class="n">features_subset</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> + <span class="n">nb_neighbors</span> <span class="o">=</span> <span class="mi">8</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">conv_add1</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l1</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">l2</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l2</span><span class="p">,</span> <span class="n">l3</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> + <span class="p">),</span> + <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span> + <span class="n">nb_neighbors</span><span class="o">=</span><span class="n">nb_neighbors</span><span class="p">,</span> + <span class="n">features_subset</span><span class="o">=</span><span class="n">features_subset</span><span class="p">,</span> + <span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">conv_add2</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l3</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">l4</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l4</span><span class="p">,</span> <span class="n">l3</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> + <span class="p">),</span> + <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span> + <span class="n">nb_neighbors</span><span class="o">=</span><span class="n">nb_neighbors</span><span class="p">,</span> + <span class="n">features_subset</span><span class="o">=</span><span class="n">features_subset</span><span class="p">,</span> + <span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">conv_add3</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l3</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">l4</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l4</span><span class="p">,</span> <span class="n">l3</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> + <span class="p">),</span> + <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span> + <span class="n">nb_neighbors</span><span class="o">=</span><span class="n">nb_neighbors</span><span class="p">,</span> + <span class="n">features_subset</span><span class="o">=</span><span class="n">features_subset</span><span class="p">,</span> + <span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">conv_add4</span> <span class="o">=</span> <span class="n">DynEdgeConv</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l3</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">l4</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l4</span><span class="p">,</span> <span class="n">l3</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(),</span> + <span class="p">),</span> + <span class="n">aggr</span><span class="o">=</span><span class="s2">"add"</span><span class="p">,</span> + <span class="n">nb_neighbors</span><span class="o">=</span><span class="n">nb_neighbors</span><span class="p">,</span> + <span class="n">features_subset</span><span class="o">=</span><span class="n">features_subset</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Post-processing operations</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nn1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l3</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="n">l1</span><span class="p">,</span> <span class="n">l4</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nn2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">l4</span><span class="p">,</span> <span class="n">l5</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nn3</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">l5</span> <span class="o">+</span> <span class="mi">5</span><span class="p">,</span> <span class="n">l6</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">lrelu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">()</span> + +<div class="viewcode-block" id="DynEdgeJINST.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Apply learnable forward pass."""</span> + <span class="c1"># Convenience variables</span> + <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> + + <span class="c1"># Calculate homophily (scalar variables)</span> + <span class="n">h_x</span><span class="p">,</span> <span class="n">h_y</span><span class="p">,</span> <span class="n">h_z</span><span class="p">,</span> <span class="n">h_t</span> <span class="o">=</span> <span class="n">calculate_xyzt_homophily</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + + <span class="n">a</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_add1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="n">b</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_add2</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="n">c</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_add3</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="n">d</span><span class="p">,</span> <span class="n">edge_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_add4</span><span class="p">(</span><span class="n">c</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + + <span class="c1"># Skip-cat</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">d</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + + <span class="c1"># Post-processing</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">nn1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lrelu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">nn2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="c1"># Aggregation across nodes</span> + <span class="n">a</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">scatter_max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="n">b</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">scatter_min</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="n">c</span> <span class="o">=</span> <span class="n">scatter_sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="n">d</span> <span class="o">=</span> <span class="n">scatter_mean</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + + <span class="c1"># Concatenate aggregations and scalar features</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">(</span> + <span class="n">a</span><span class="p">,</span> + <span class="n">b</span><span class="p">,</span> + <span class="n">c</span><span class="p">,</span> + <span class="n">d</span><span class="p">,</span> + <span class="n">h_t</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> + <span class="n">h_x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> + <span class="n">h_y</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> + <span class="n">h_z</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> + <span class="n">data</span><span class="o">.</span><span class="n">n_pulses</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> + <span class="p">),</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Read-out</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lrelu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">nn3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lrelu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">x</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html b/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html new file mode 100644 index 000000000..8895fa50c --- /dev/null +++ b/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html @@ -0,0 +1,618 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.gnn.dynedge_kaggle_tito — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/gnn/dynedge_kaggle_tito" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.gnn.dynedge_kaggle_tito </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-gnn-dynedge-kaggle-tito--page-root">Source code for graphnet.models.gnn.dynedge_kaggle_tito</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Implementation of DynEdge architecture used in.</span> + +<span class="sd"> IceCube - Neutrinos in Deep Ice</span> +<span class="sd">Reconstruct the direction of neutrinos from the Universe to the South Pole</span> + +<span class="sd">Kaggle competition.</span> + +<span class="sd">Solution by TITO.</span> +<span class="sd">"""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span> + +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> +<span class="kn">from</span> <span class="nn">torch_geometric.utils</span> <span class="kn">import</span> <span class="n">to_dense_batch</span> +<span class="kn">from</span> <span class="nn">torch_scatter</span> <span class="kn">import</span> <span class="n">scatter_max</span><span class="p">,</span> <span class="n">scatter_mean</span><span class="p">,</span> <span class="n">scatter_min</span><span class="p">,</span> <span class="n">scatter_sum</span> + +<span class="kn">from</span> <span class="nn">graphnet.models.components.layers</span> <span class="kn">import</span> <span class="n">DynTrans</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span> +<span class="kn">from</span> <span class="nn">graphnet.models.utils</span> <span class="kn">import</span> <span class="n">calculate_xyzt_homophily</span> + +<span class="n">GLOBAL_POOLINGS</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"min"</span><span class="p">:</span> <span class="n">scatter_min</span><span class="p">,</span> + <span class="s2">"max"</span><span class="p">:</span> <span class="n">scatter_max</span><span class="p">,</span> + <span class="s2">"sum"</span><span class="p">:</span> <span class="n">scatter_sum</span><span class="p">,</span> + <span class="s2">"mean"</span><span class="p">:</span> <span class="n">scatter_mean</span><span class="p">,</span> +<span class="p">}</span> + + +<div class="viewcode-block" id="DynEdgeTITO"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO">[docs]</a> +<span class="k">class</span> <span class="nc">DynEdgeTITO</span><span class="p">(</span><span class="n">GNN</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""DynEdge (dynamical edge convolutional) model."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">features_subset</span><span class="p">:</span> <span class="nb">slice</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> + <span class="n">dyntrans_layer_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">global_pooling_schemes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"max"</span><span class="p">],</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `DynEdge`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> nb_inputs: Number of input features on each node.</span> +<span class="sd"> features_subset: The subset of latent features on each node that</span> +<span class="sd"> are used as metric dimensions when performing the k-nearest</span> +<span class="sd"> neighbours clustering. Defaults to [0,1,2,3].</span> +<span class="sd"> dyntrans_layer_sizes: The layer sizes, or latent feature dimenions,</span> +<span class="sd"> used in the `DynTrans` layer.</span> +<span class="sd"> global_pooling_schemes: The list global pooling schemes to use.</span> +<span class="sd"> Options are: "min", "max", "mean", and "sum".</span> +<span class="sd"> """</span> + <span class="c1"># DynEdge layer sizes</span> + <span class="k">if</span> <span class="n">dyntrans_layer_sizes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">dyntrans_layer_sizes</span> <span class="o">=</span> <span class="p">[</span> + <span class="p">(</span> + <span class="mi">256</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">),</span> + <span class="p">(</span> + <span class="mi">256</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">),</span> + <span class="p">(</span> + <span class="mi">256</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">),</span> + <span class="p">]</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dyntrans_layer_sizes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dyntrans_layer_sizes</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dyntrans_layer_sizes</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dyntrans_layer_sizes</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span> + <span class="nb">all</span><span class="p">(</span><span class="n">size</span> <span class="o">></span> <span class="mi">0</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">sizes</span><span class="p">)</span> <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="n">dyntrans_layer_sizes</span> + <span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_dyntrans_layer_sizes</span> <span class="o">=</span> <span class="n">dyntrans_layer_sizes</span> + + <span class="c1"># Post-processing layer sizes</span> + <span class="n">post_processing_layer_sizes</span> <span class="o">=</span> <span class="p">[</span> + <span class="mi">336</span><span class="p">,</span> + <span class="mi">256</span><span class="p">,</span> + <span class="p">]</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing_layer_sizes</span> <span class="o">=</span> <span class="n">post_processing_layer_sizes</span> + + <span class="c1"># Read-out layer sizes</span> + <span class="n">readout_layer_sizes</span> <span class="o">=</span> <span class="p">[</span> + <span class="mi">256</span><span class="p">,</span> + <span class="mi">128</span><span class="p">,</span> + <span class="p">]</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span> <span class="o">=</span> <span class="n">readout_layer_sizes</span> + + <span class="c1"># Global pooling scheme(s)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">global_pooling_schemes</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">global_pooling_schemes</span> <span class="o">=</span> <span class="p">[</span><span class="n">global_pooling_schemes</span><span class="p">]</span> + + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">global_pooling_schemes</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="k">for</span> <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="n">global_pooling_schemes</span><span class="p">:</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="n">GLOBAL_POOLINGS</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"Global pooling scheme </span><span class="si">{</span><span class="n">pooling_scheme</span><span class="si">}</span><span class="s2"> not supported."</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">assert</span> <span class="n">global_pooling_schemes</span> <span class="ow">is</span> <span class="kc">None</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span> <span class="o">=</span> <span class="n">global_pooling_schemes</span> + + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">,</span> <span class="p">(</span> + <span class="s2">"No global pooling schemes were request, so cannot add global"</span> + <span class="s2">" variables after pooling."</span> + <span class="p">)</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">nb_inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> + + <span class="c1"># Remaining member variables()</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_activation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">()</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> <span class="o">=</span> <span class="n">nb_inputs</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span> <span class="o">=</span> <span class="mi">5</span> <span class="o">+</span> <span class="n">nb_inputs</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_features_subset</span> <span class="o">=</span> <span class="n">features_subset</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_construct_layers</span><span class="p">()</span> + + <span class="k">def</span> <span class="nf">_construct_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct layers (torch.nn.Modules)."""</span> + <span class="c1"># Convolutional operations</span> + <span class="n">nb_input_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">()</span> + <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">nb_input_features</span> + <span class="k">for</span> <span class="n">sizes</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dyntrans_layer_sizes</span><span class="p">:</span> + <span class="n">conv_layer</span> <span class="o">=</span> <span class="n">DynTrans</span><span class="p">(</span> + <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">sizes</span><span class="p">),</span> + <span class="n">aggr</span><span class="o">=</span><span class="s2">"max"</span><span class="p">,</span> + <span class="n">features_subset</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_features_subset</span><span class="p">,</span> + <span class="n">n_head</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">conv_layer</span><span class="p">)</span> + <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> + + <span class="n">post_processing_layers</span> <span class="o">=</span> <span class="p">[]</span> + <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing_layer_sizes</span> + <span class="p">)</span> + <span class="k">for</span> <span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span> + <span class="n">post_processing_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span> + <span class="n">post_processing_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span> + <span class="n">last_posting_layer_output_dim</span> <span class="o">=</span> <span class="n">nb_out</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">post_processing_layers</span><span class="p">)</span> + + <span class="c1"># Read-out operations</span> + <span class="n">nb_poolings</span> <span class="o">=</span> <span class="p">(</span> + <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span> + <span class="k">else</span> <span class="mi">1</span> + <span class="p">)</span> + <span class="n">nb_latent_features</span> <span class="o">=</span> <span class="n">last_posting_layer_output_dim</span> <span class="o">*</span> <span class="n">nb_poolings</span> + <span class="n">nb_latent_features</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_global_variables</span> + + <span class="n">readout_layers</span> <span class="o">=</span> <span class="p">[]</span> + <span class="n">layer_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">nb_latent_features</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_readout_layer_sizes</span><span class="p">)</span> + <span class="k">for</span> <span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">layer_sizes</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">layer_sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span> + <span class="n">readout_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nb_in</span><span class="p">,</span> <span class="n">nb_out</span><span class="p">))</span> + <span class="n">readout_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_activation</span><span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_readout</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">readout_layers</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_global_pooling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform global pooling."""</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span> + <span class="n">pooled</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">for</span> <span class="n">pooling_scheme</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling_schemes</span><span class="p">:</span> + <span class="n">pooling_fn</span> <span class="o">=</span> <span class="n">GLOBAL_POOLINGS</span><span class="p">[</span><span class="n">pooling_scheme</span><span class="p">]</span> + <span class="n">pooled_x</span> <span class="o">=</span> <span class="n">pooling_fn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span> + <span class="c1"># `scatter_{min,max}`, which return also an argument, vs.</span> + <span class="c1"># `scatter_{mean,sum}`</span> + <span class="n">pooled_x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">pooled_x</span> + <span class="n">pooled</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pooled_x</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">pooled</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_calculate_global_variables</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="n">edge_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="n">batch</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> + <span class="o">*</span><span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Calculate global variables."""</span> + <span class="c1"># Calculate homophily (scalar variables)</span> + <span class="n">h_x</span><span class="p">,</span> <span class="n">h_y</span><span class="p">,</span> <span class="n">h_z</span><span class="p">,</span> <span class="n">h_t</span> <span class="o">=</span> <span class="n">calculate_xyzt_homophily</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + + <span class="c1"># Calculate mean features</span> + <span class="n">global_means</span> <span class="o">=</span> <span class="n">scatter_mean</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + + <span class="c1"># Add global variables</span> + <span class="n">global_variables</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">global_means</span><span class="p">,</span> + <span class="n">h_x</span><span class="p">,</span> + <span class="n">h_y</span><span class="p">,</span> + <span class="n">h_z</span><span class="p">,</span> + <span class="n">h_t</span><span class="p">,</span> + <span class="p">]</span> + <span class="o">+</span> <span class="p">[</span><span class="n">attr</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">additional_attributes</span><span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">global_variables</span> + +<div class="viewcode-block" id="DynEdgeTITO.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Apply learnable forward pass."""</span> + <span class="c1"># Convenience variables</span> + <span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">data</span><span class="o">.</span><span class="n">batch</span> + + <span class="n">global_variables</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_calculate_global_variables</span><span class="p">(</span> + <span class="n">x</span><span class="p">,</span> + <span class="n">edge_index</span><span class="p">,</span> + <span class="n">batch</span><span class="p">,</span> + <span class="n">torch</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">n_pulses</span><span class="p">),</span> + <span class="p">)</span> + + <span class="c1"># DynEdge-convolutions</span> + <span class="k">for</span> <span class="n">conv_layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_conv_layers</span><span class="p">:</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">conv_layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + + <span class="n">x</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">to_dense_batch</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> + + <span class="c1"># Post-processing</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_post_processing</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="c1"># (Optional) Global pooling</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_global_pooling</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">x</span><span class="p">,</span> + <span class="n">global_variables</span><span class="p">,</span> + <span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Read-out</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_readout</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">x</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/gnn.html b/_modules/graphnet/models/gnn/gnn.html new file mode 100644 index 000000000..7a3a22f46 --- /dev/null +++ b/_modules/graphnet/models/gnn/gnn.html @@ -0,0 +1,403 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.gnn.gnn — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/gnn/gnn" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.gnn.gnn </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-gnn-gnn--page-root">Source code for graphnet.models.gnn.gnn</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Base GNN-specific `Model` class(es)."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span> + +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> + + +<div class="viewcode-block" id="GNN"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN">[docs]</a> +<span class="k">class</span> <span class="nc">GNN</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base class for all core GNN models in graphnet."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nb_inputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">nb_outputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `GNN`."""</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> + + <span class="c1"># Member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> <span class="o">=</span> <span class="n">nb_inputs</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_outputs</span> <span class="o">=</span> <span class="n">nb_outputs</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">nb_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return number of input features."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">nb_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return number of output features."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_outputs</span> + +<div class="viewcode-block" id="GNN.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN.forward">[docs]</a> + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Apply learnable forward pass in model."""</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/graphs/edges/edges.html b/_modules/graphnet/models/graphs/edges/edges.html new file mode 100644 index 000000000..e681629cf --- /dev/null +++ b/_modules/graphnet/models/graphs/edges/edges.html @@ -0,0 +1,563 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.graphs.edges.edges — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" /> + <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../../about.html" /> + <link rel="index" title="Index" href="../../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/graphs/edges/edges" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.graphs.edges.edges </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../../"versions.json"", + target_loc = "../../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-graphs-edges-edges--page-root">Source code for graphnet.models.graphs.edges.edges</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Class(es) for building/connecting graphs."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span> +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span><span class="p">,</span> <span class="n">ABC</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn</span> <span class="kn">import</span> <span class="n">knn_graph</span><span class="p">,</span> <span class="n">radius_graph</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.models.utils</span> <span class="kn">import</span> <span class="n">calculate_distance_matrix</span> +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> + + +<div class="viewcode-block" id="EdgeDefinition"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition">[docs]</a> +<span class="k">class</span> <span class="nc">EdgeDefinition</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> <span class="c1"># pylint: disable=too-few-public-methods</span> +<span class="w"> </span><span class="sd">"""Base class for graph building."""</span> + +<div class="viewcode-block" id="EdgeDefinition.forward"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct edges based on problem specific implementation of.</span> + +<span class="sd"> ´_construct_edges´</span> + +<span class="sd"> Args:</span> +<span class="sd"> graph: a graph without edges</span> + +<span class="sd"> Returns:</span> +<span class="sd"> graph: a graph with edges</span> +<span class="sd"> """</span> + <span class="k">if</span> <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warnonce</span><span class="p">(</span> + <span class="s2">"GraphBuilder received graph with pre-existing "</span> + <span class="s2">"structure. Will overwrite."</span> + <span class="p">)</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_edges</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span></div> + + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_construct_edges</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´.</span> + +<span class="sd"> Args:</span> +<span class="sd"> graph: graph without edges</span> + +<span class="sd"> Returns:</span> +<span class="sd"> graph: graph with edges assigned.</span> +<span class="sd"> """</span></div> + + + +<div class="viewcode-block" id="KNNEdges"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.KNNEdges">[docs]</a> +<span class="k">class</span> <span class="nc">KNNEdges</span><span class="p">(</span><span class="n">EdgeDefinition</span><span class="p">):</span> <span class="c1"># pylint: disable=too-few-public-methods</span> +<span class="w"> </span><span class="sd">"""Builds edges from the k-nearest neighbours."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">nb_nearest_neighbours</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""K-NN Edge definition.</span> + +<span class="sd"> Will connect nodes together with their ´nb_nearest_neighbours´</span> +<span class="sd"> nearest neighbours in the feature space given by ´columns´.</span> + +<span class="sd"> Args:</span> +<span class="sd"> nb_nearest_neighbours: number of neighbours.</span> +<span class="sd"> columns: Node features to use for distance calculation.</span> +<span class="sd"> Defaults to [0,1,2].</span> +<span class="sd"> """</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + + <span class="c1"># Member variable(s)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_nearest_neighbours</span> <span class="o">=</span> <span class="n">nb_nearest_neighbours</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span> <span class="o">=</span> <span class="n">columns</span> + + <span class="k">def</span> <span class="nf">_construct_edges</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Define K-NN edges."""</span> + <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">knn_graph</span><span class="p">(</span> + <span class="n">graph</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span><span class="p">],</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_nearest_neighbours</span><span class="p">,</span> + <span class="n">graph</span><span class="o">.</span><span class="n">batch</span><span class="p">,</span> + <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">graph</span></div> + + + +<div class="viewcode-block" id="RadialEdges"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.RadialEdges">[docs]</a> +<span class="k">class</span> <span class="nc">RadialEdges</span><span class="p">(</span><span class="n">EdgeDefinition</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Builds graph from a sphere of chosen radius centred at each node."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">radius</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> + <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Radial edges.</span> + +<span class="sd"> Connects each node to other nodes that are within a sphere of</span> +<span class="sd"> radius ´r´ centered at the node. The feature space of ´r´ is defined</span> +<span class="sd"> by ´columns´</span> + +<span class="sd"> Args:</span> +<span class="sd"> radius: radius of sphere</span> +<span class="sd"> columns: columns of the node feature matrix used.</span> +<span class="sd"> Defaults to [0,1,2].</span> +<span class="sd"> """</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + + <span class="c1"># Member variable(s)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_radius</span> <span class="o">=</span> <span class="n">radius</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span> <span class="o">=</span> <span class="n">columns</span> + + <span class="k">def</span> <span class="nf">_construct_edges</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Define radial edges."""</span> + <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">radius_graph</span><span class="p">(</span> + <span class="n">graph</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span><span class="p">],</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_radius</span><span class="p">,</span> + <span class="n">graph</span><span class="o">.</span><span class="n">batch</span><span class="p">,</span> + <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">graph</span></div> + + + +<div class="viewcode-block" id="EuclideanEdges"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EuclideanEdges">[docs]</a> +<span class="k">class</span> <span class="nc">EuclideanEdges</span><span class="p">(</span><span class="n">EdgeDefinition</span><span class="p">):</span> <span class="c1"># pylint: disable=too-few-public-methods</span> +<span class="w"> </span><span class="sd">"""Builds edges according to Euclidean distance between nodes.</span> + +<span class="sd"> See https://arxiv.org/pdf/1809.06166.pdf.</span> +<span class="sd"> """</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">sigma</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> + <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> + <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `EuclideanEdges`."""</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="n">columns</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">columns</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> + + <span class="c1"># Member variable(s)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_sigma</span> <span class="o">=</span> <span class="n">sigma</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span> <span class="o">=</span> <span class="n">threshold</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span> <span class="o">=</span> <span class="n">columns</span> + + <span class="k">def</span> <span class="nf">_construct_edges</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Forward pass."""</span> + <span class="c1"># Constructs the adjacency matrix from the raw, DOM-level data and</span> + <span class="c1"># returns this matrix</span> + <span class="k">if</span> <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span> + <span class="s2">"WARNING: GraphBuilder received graph with pre-existing "</span> + <span class="s2">"structure. Will overwrite."</span> + <span class="p">)</span> + + <span class="n">xyz_coords</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_columns</span><span class="p">]</span> + + <span class="c1"># Construct block-diagonal matrix indicating whether pulses belong to</span> + <span class="c1"># the same event in the batch</span> + <span class="n">batch_mask</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">batch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="n">graph</span><span class="o">.</span><span class="n">batch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span> + <span class="p">)</span> + + <span class="n">distance_matrix</span> <span class="o">=</span> <span class="n">calculate_distance_matrix</span><span class="p">(</span><span class="n">xyz_coords</span><span class="p">)</span> + <span class="n">affinity_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span> + <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">distance_matrix</span><span class="o">**</span><span class="mi">2</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sigma</span><span class="o">**</span><span class="mi">2</span> + <span class="p">)</span> + + <span class="c1"># Use softmax to normalise all adjacencies to one for each node</span> + <span class="n">exp_row_sums</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">affinity_matrix</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + <span class="n">weighted_adj_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span> + <span class="n">affinity_matrix</span> + <span class="p">)</span> <span class="o">/</span> <span class="n">exp_row_sums</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + + <span class="c1"># Only include edges with weights that exceed the chosen threshold (and</span> + <span class="c1"># are part of the same event)</span> + <span class="n">sources</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span> + <span class="p">(</span><span class="n">weighted_adj_matrix</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">_threshold</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">batch_mask</span><span class="p">)</span> + <span class="p">)</span> + <span class="n">edge_weights</span> <span class="o">=</span> <span class="n">weighted_adj_matrix</span><span class="p">[</span><span class="n">sources</span><span class="p">,</span> <span class="n">targets</span><span class="p">]</span> + + <span class="n">graph</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">sources</span><span class="p">,</span> <span class="n">targets</span><span class="p">))</span> + <span class="n">graph</span><span class="o">.</span><span class="n">edge_weight</span> <span class="o">=</span> <span class="n">edge_weights</span> + + <span class="k">return</span> <span class="n">graph</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/graphs/graph_definition.html b/_modules/graphnet/models/graphs/graph_definition.html new file mode 100644 index 000000000..75075b762 --- /dev/null +++ b/_modules/graphnet/models/graphs/graph_definition.html @@ -0,0 +1,634 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.graphs.graph_definition — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/graphs/graph_definition" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.graphs.graph_definition </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-graphs-graph-definition--page-root">Source code for graphnet.models.graphs.graph_definition</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Modules for defining graphs.</span> + +<span class="sd">These are self-contained graph definitions that hold all the graph-altering</span> +<span class="sd">code in graphnet. These modules define what the GNNs sees as input and can be</span> +<span class="sd">passed to dataloaders during training and deployment.</span> +<span class="sd">"""</span> + + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Callable</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> + +<span class="kn">from</span> <span class="nn">graphnet.models.detector</span> <span class="kn">import</span> <span class="n">Detector</span> +<span class="kn">from</span> <span class="nn">.edges</span> <span class="kn">import</span> <span class="n">EdgeDefinition</span> +<span class="kn">from</span> <span class="nn">.nodes</span> <span class="kn">import</span> <span class="n">NodeDefinition</span> +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> + + +<div class="viewcode-block" id="GraphDefinition"> +<a class="viewcode-back" href="../../../../api/graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition">[docs]</a> +<span class="k">class</span> <span class="nc">GraphDefinition</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""An Abstract class to create graph definitions from."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">detector</span><span class="p">:</span> <span class="n">Detector</span><span class="p">,</span> + <span class="n">node_definition</span><span class="p">:</span> <span class="n">NodeDefinition</span><span class="p">,</span> + <span class="n">edge_definition</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">EdgeDefinition</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct ´GraphDefinition´. The ´detector´ holds.</span> + +<span class="sd"> ´Detector´-specific code. E.g. scaling/standardization and geometry</span> +<span class="sd"> tables.</span> + +<span class="sd"> ´node_definition´ defines the nodes in the graph.</span> + +<span class="sd"> ´edge_definition´ defines the connectivity of the nodes in the graph.</span> + +<span class="sd"> Args:</span> +<span class="sd"> detector: The corresponding ´Detector´ representing the data.</span> +<span class="sd"> node_definition: Definition of nodes.</span> +<span class="sd"> edge_definition: Definition of edges. Defaults to None.</span> +<span class="sd"> node_feature_names: Names of node feature columns. Defaults to None</span> +<span class="sd"> dtype: data type used for node features. e.g. ´torch.float´</span> +<span class="sd"> """</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + + <span class="c1"># Member Variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_detector</span> <span class="o">=</span> <span class="n">detector</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_edge_definition</span> <span class="o">=</span> <span class="n">edge_definition</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_definition</span> <span class="o">=</span> <span class="n">node_definition</span> + <span class="k">if</span> <span class="n">node_feature_names</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="c1"># Assume all features in Detector is used.</span> + <span class="n">node_feature_names</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_detector</span><span class="o">.</span><span class="n">feature_map</span><span class="p">()</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="c1"># type: ignore</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span> <span class="o">=</span> <span class="n">node_feature_names</span> + + <span class="c1"># Set data type</span> + <span class="bp">self</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> + + <span class="c1"># Set Input / Output dimensions</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_definition</span><span class="o">.</span><span class="n">set_number_of_inputs</span><span class="p">(</span> + <span class="n">node_feature_names</span><span class="o">=</span><span class="n">node_feature_names</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nb_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_definition</span><span class="o">.</span><span class="n">nb_outputs</span> + +<div class="viewcode-block" id="GraphDefinition.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span> <span class="c1"># type: ignore</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">node_features</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> + <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">truth_dicts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">custom_label_functions</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">data_path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct graph as ´Data´ object.</span> + +<span class="sd"> Args:</span> +<span class="sd"> node_features: node features for graph. Shape ´[num_nodes, d]´</span> +<span class="sd"> node_feature_names: name of each column. Shape ´[,d]´.</span> +<span class="sd"> truth_dicts: Dictionary containing truth labels.</span> +<span class="sd"> custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.</span> +<span class="sd"> loss_weight_column: Name of column that holds loss weight. Defaults to None.</span> +<span class="sd"> loss_weight: Loss weight associated with event. Defaults to None.</span> +<span class="sd"> loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.</span> +<span class="sd"> data_path: Path to dataset data files. Defaults to None.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> graph</span> +<span class="sd"> """</span> + <span class="c1"># Checks</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_validate_input</span><span class="p">(</span> + <span class="n">node_features</span><span class="o">=</span><span class="n">node_features</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="o">=</span><span class="n">node_feature_names</span> + <span class="p">)</span> + + <span class="c1"># Transform to pytorch tensor</span> + <span class="n">node_features</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">node_features</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> + + <span class="c1"># Standardize / Scale node features</span> + <span class="n">node_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_detector</span><span class="p">(</span><span class="n">node_features</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">)</span> + + <span class="c1"># Create graph</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_definition</span><span class="p">(</span><span class="n">node_features</span><span class="p">)</span> + + <span class="c1"># Attach number of pulses as static attribute.</span> + <span class="n">graph</span><span class="o">.</span><span class="n">n_pulses</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">node_features</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> + + <span class="c1"># Assign edges</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_edge_definition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_edge_definition</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warnonce</span><span class="p">(</span> + <span class="s2">"No EdgeDefinition provided. Graphs will not have edges defined!"</span> + <span class="p">)</span> + + <span class="c1"># Attach data path - useful for Ensemble datasets.</span> + <span class="k">if</span> <span class="n">data_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">graph</span><span class="p">[</span><span class="s2">"dataset_path"</span><span class="p">]</span> <span class="o">=</span> <span class="n">data_path</span> + + <span class="c1"># Attach loss weights if they exist</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_loss_weights</span><span class="p">(</span> + <span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span> + <span class="n">loss_weight</span><span class="o">=</span><span class="n">loss_weight</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="o">=</span><span class="n">loss_weight_column</span><span class="p">,</span> + <span class="n">loss_weight_default_value</span><span class="o">=</span><span class="n">loss_weight_default_value</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Attach default truth labels and node truths</span> + <span class="k">if</span> <span class="n">truth_dicts</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_truth</span><span class="p">(</span><span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span> <span class="n">truth_dicts</span><span class="o">=</span><span class="n">truth_dicts</span><span class="p">)</span> + + <span class="c1"># Attach custom truth labels</span> + <span class="k">if</span> <span class="n">custom_label_functions</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_custom_labels</span><span class="p">(</span> + <span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span> <span class="n">custom_label_functions</span><span class="o">=</span><span class="n">custom_label_functions</span> + <span class="p">)</span> + + <span class="c1"># Attach node features as seperate fields. MAY NOT CONTAIN 'x'</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_features_individually</span><span class="p">(</span> + <span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="o">=</span><span class="n">node_feature_names</span> + <span class="p">)</span> + + <span class="c1"># Add GraphDefinition Stamp</span> + <span class="n">graph</span><span class="p">[</span><span class="s2">"graph_definition"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> + <span class="k">return</span> <span class="n">graph</span></div> + + + <span class="k">def</span> <span class="nf">_validate_input</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">node_features</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> + <span class="c1"># node feature matrix dimension check</span> + <span class="k">assert</span> <span class="n">node_features</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">)</span> + + <span class="c1"># check that provided features for input is the same that the ´Graph´</span> + <span class="c1"># was instantiated with.</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"""Input features (</span><span class="si">{</span><span class="n">node_feature_names</span><span class="si">}</span><span class="s2">) is not what </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> was instatiated with (</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span><span class="si">}</span><span class="s2">)"""</span> + <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">)):</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="n">node_feature_names</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">""" Order of node features in data are not the same as expected. Got </span><span class="si">{</span><span class="n">node_feature_names</span><span class="si">}</span><span class="s2"> vs. </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_node_feature_names</span><span class="si">}</span><span class="s2">"""</span> + + <span class="k">def</span> <span class="nf">_add_loss_weights</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Attempt to store a loss weight in the graph for use during training.</span> + +<span class="sd"> I.e. `graph[loss_weight_column] = loss_weight`</span> + +<span class="sd"> Args:</span> +<span class="sd"> loss_weight: The non-negative weight to be stored.</span> +<span class="sd"> graph: Data object representing the event.</span> +<span class="sd"> loss_weight_column: The name under which the weight is stored in</span> +<span class="sd"> the graph.</span> +<span class="sd"> loss_weight_default_value: The default value used if</span> +<span class="sd"> none was retrieved.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> A graph with loss weight added, if available.</span> +<span class="sd"> """</span> + <span class="c1"># Add loss weight to graph.</span> + <span class="k">if</span> <span class="n">loss_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">loss_weight_column</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="c1"># No loss weight was retrieved, i.e., it is missing for the current</span> + <span class="c1"># event.</span> + <span class="k">if</span> <span class="n">loss_weight</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span> + <span class="k">if</span> <span class="n">loss_weight_default_value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> + <span class="s2">"At least one event is missing an entry in "</span> + <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">loss_weight_column</span><span class="si">}</span><span class="s2"> "</span> + <span class="s2">"but loss_weight_default_value is None."</span> + <span class="p">)</span> + <span class="n">graph</span><span class="p">[</span><span class="n">loss_weight_column</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight_default_value</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> + <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">graph</span><span class="p">[</span><span class="n">loss_weight_column</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span> + <span class="n">loss_weight</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> + <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="k">return</span> <span class="n">graph</span> + + <span class="k">def</span> <span class="nf">_add_truth</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">truth_dicts</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Add truth labels from ´truth_dicts´ to ´graph´.</span> + +<span class="sd"> I.e. ´graph[key] = truth_dict[key]´</span> + + +<span class="sd"> Args:</span> +<span class="sd"> graph: graph where the label will be stored</span> +<span class="sd"> truth_dicts: dictionary containing the labels</span> + +<span class="sd"> Returns:</span> +<span class="sd"> graph with labels</span> +<span class="sd"> """</span> + <span class="c1"># Write attributes, either target labels, truth info or original</span> + <span class="c1"># features.</span> + <span class="k">for</span> <span class="n">truth_dict</span> <span class="ow">in</span> <span class="n">truth_dicts</span><span class="p">:</span> + <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">truth_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> + <span class="k">try</span><span class="p">:</span> + <span class="n">graph</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> + <span class="k">except</span> <span class="ne">TypeError</span><span class="p">:</span> + <span class="c1"># Cannot convert `value` to Tensor due to its data type,</span> + <span class="c1"># e.g. `str`.</span> + <span class="bp">self</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span> + <span class="p">(</span> + <span class="sa">f</span><span class="s2">"Could not assign `</span><span class="si">{</span><span class="n">key</span><span class="si">}</span><span class="s2">` with type "</span> + <span class="sa">f</span><span class="s2">"'</span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">value</span><span class="p">)</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">' as attribute to graph."</span> + <span class="p">)</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">graph</span> + + <span class="k">def</span> <span class="nf">_add_features_individually</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> + <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> + <span class="c1"># Additionally add original features as (static) attributes</span> + <span class="n">graph</span><span class="o">.</span><span class="n">features</span> <span class="o">=</span> <span class="n">node_feature_names</span> + <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">feature</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">):</span> + <span class="k">if</span> <span class="n">feature</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">"x"</span><span class="p">]:</span> <span class="c1"># reserved for node features.</span> + <span class="n">graph</span><span class="p">[</span><span class="n">feature</span><span class="p">]</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">index</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span> + <span class="k">else</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warnonce</span><span class="p">(</span> +<span class="w"> </span><span class="sd">"""Cannot assign graph['x']. This field is reserved for node features. Please rename your input feature."""</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">graph</span> + + <span class="k">def</span> <span class="nf">_add_custom_labels</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> + <span class="n">custom_label_functions</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">Any</span><span class="p">]],</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> + <span class="c1"># Add custom labels to the graph</span> + <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">fn</span> <span class="ow">in</span> <span class="n">custom_label_functions</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> + <span class="n">graph</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">fn</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span> + <span class="k">return</span> <span class="n">graph</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/graphs/graphs.html b/_modules/graphnet/models/graphs/graphs.html new file mode 100644 index 000000000..5f3dc4be7 --- /dev/null +++ b/_modules/graphnet/models/graphs/graphs.html @@ -0,0 +1,410 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.graphs.graphs — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/graphs/graphs" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.graphs.graphs </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-graphs-graphs--page-root">Source code for graphnet.models.graphs.graphs</h1><div class="highlight"><pre> +<span></span><span class="sd">"""A module containing different graph representations in GraphNeT."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span> +<span class="kn">import</span> <span class="nn">torch</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">.graph_definition</span> <span class="kn">import</span> <span class="n">GraphDefinition</span> +<span class="kn">from</span> <span class="nn">graphnet.models.detector</span> <span class="kn">import</span> <span class="n">Detector</span> +<span class="kn">from</span> <span class="nn">graphnet.models.graphs.edges</span> <span class="kn">import</span> <span class="n">EdgeDefinition</span><span class="p">,</span> <span class="n">KNNEdges</span> +<span class="kn">from</span> <span class="nn">graphnet.models.graphs.nodes</span> <span class="kn">import</span> <span class="n">NodeDefinition</span> + + +<div class="viewcode-block" id="KNNGraph"> +<a class="viewcode-back" href="../../../../api/graphnet.models.graphs.graphs.html#graphnet.models.graphs.graphs.KNNGraph">[docs]</a> +<span class="k">class</span> <span class="nc">KNNGraph</span><span class="p">(</span><span class="n">GraphDefinition</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""A Graph representation where Edges are drawn to nearest neighbours."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">detector</span><span class="p">:</span> <span class="n">Detector</span><span class="p">,</span> + <span class="n">node_definition</span><span class="p">:</span> <span class="n">NodeDefinition</span><span class="p">,</span> + <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span> + <span class="n">nb_nearest_neighbours</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> + <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct k-nn graph representation.</span> + +<span class="sd"> Args:</span> +<span class="sd"> detector: Detector that represents your data.</span> +<span class="sd"> node_definition: Definition of nodes in the graph.</span> +<span class="sd"> node_feature_names: Name of node features.</span> +<span class="sd"> dtype: data type for node features.</span> +<span class="sd"> nb_nearest_neighbours: Number of edges for each node. Defaults to 8.</span> +<span class="sd"> columns: node feature columns used for distance calculation</span> +<span class="sd"> . Defaults to [0, 1, 2].</span> +<span class="sd"> """</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span> + <span class="n">detector</span><span class="o">=</span><span class="n">detector</span><span class="p">,</span> + <span class="n">node_definition</span><span class="o">=</span><span class="n">node_definition</span><span class="p">,</span> + <span class="n">edge_definition</span><span class="o">=</span><span class="n">KNNEdges</span><span class="p">(</span> + <span class="n">nb_nearest_neighbours</span><span class="o">=</span><span class="n">nb_nearest_neighbours</span><span class="p">,</span> + <span class="n">columns</span><span class="o">=</span><span class="n">columns</span><span class="p">,</span> + <span class="p">),</span> + <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> + <span class="n">node_feature_names</span><span class="o">=</span><span class="n">node_feature_names</span><span class="p">,</span> + <span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/graphs/nodes/nodes.html b/_modules/graphnet/models/graphs/nodes/nodes.html new file mode 100644 index 000000000..3e7040189 --- /dev/null +++ b/_modules/graphnet/models/graphs/nodes/nodes.html @@ -0,0 +1,444 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.graphs.nodes.nodes — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../../_static/material.css?v=79c92029" /> + <script src="../../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../../about.html" /> + <link rel="index" title="Index" href="../../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/graphs/nodes/nodes" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.graphs.nodes.nodes </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../../"versions.json"", + target_loc = "../../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-graphs-nodes-nodes--page-root">Source code for graphnet.models.graphs.nodes.nodes</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Class(es) for building/connecting graphs."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span> +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> + + +<div class="viewcode-block" id="NodeDefinition"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition">[docs]</a> +<span class="k">class</span> <span class="nc">NodeDefinition</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> <span class="c1"># pylint: disable=too-few-public-methods</span> +<span class="w"> </span><span class="sd">"""Base class for graph building."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `Detector`."""</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + +<div class="viewcode-block" id="NodeDefinition.forward"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct nodes from raw node features.</span> + +<span class="sd"> Args:</span> +<span class="sd"> x: standardized node features with shape ´[num_pulses, d]´,</span> +<span class="sd"> where ´d´ is the number of node features.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> graph: a graph without edges</span> +<span class="sd"> """</span> + <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_nodes</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="k">return</span> <span class="n">graph</span></div> + + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">nb_outputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return number of output features.</span> + +<span class="sd"> This the default, but may be overridden by specific inheriting classes.</span> +<span class="sd"> """</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span> + +<div class="viewcode-block" id="NodeDefinition.set_number_of_inputs"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">set_number_of_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">node_feature_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">])</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return number of inputs expected by node definition.</span> + +<span class="sd"> Args:</span> +<span class="sd"> node_feature_names: name of each node feature column.</span> +<span class="sd"> """</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">node_feature_names</span><span class="p">)</span></div> + + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_construct_nodes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct nodes from raw node features ´x´.</span> + +<span class="sd"> Args:</span> +<span class="sd"> x: standardized node features with shape ´[num_pulses, d]´,</span> +<span class="sd"> where ´d´ is the number of node features.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> graph: graph without edges.</span> +<span class="sd"> """</span></div> + + + +<div class="viewcode-block" id="NodesAsPulses"> +<a class="viewcode-back" href="../../../../../api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodesAsPulses">[docs]</a> +<span class="k">class</span> <span class="nc">NodesAsPulses</span><span class="p">(</span><span class="n">NodeDefinition</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Represent each measured pulse of Cherenkov Radiation as a node."""</span> + + <span class="k">def</span> <span class="nf">_construct_nodes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Data</span><span class="p">:</span> + <span class="k">return</span> <span class="n">Data</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/model.html b/_modules/graphnet/models/model.html new file mode 100644 index 000000000..9654c1e81 --- /dev/null +++ b/_modules/graphnet/models/model.html @@ -0,0 +1,726 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.model — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/model" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.model </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-model--page-root">Source code for graphnet.models.model</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Base class(es) for building models."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span> +<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span> +<span class="kn">import</span> <span class="nn">dill</span> +<span class="kn">import</span> <span class="nn">os.path</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">LightningModule</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning.callbacks.callback</span> <span class="kn">import</span> <span class="n">Callback</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning.callbacks</span> <span class="kn">import</span> <span class="n">EarlyStopping</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning.loggers.logger</span> <span class="kn">import</span> <span class="n">Logger</span> <span class="k">as</span> <span class="n">LightningLogger</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">SequentialSampler</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">Configurable</span><span class="p">,</span> <span class="n">ModelConfig</span> +<span class="kn">from</span> <span class="nn">graphnet.training.callbacks</span> <span class="kn">import</span> <span class="n">ProgressBar</span> + + +<div class="viewcode-block" id="Model"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model">[docs]</a> +<span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">Logger</span><span class="p">,</span> <span class="n">Configurable</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">ABC</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base class for all models in graphnet."""</span> + +<div class="viewcode-block" id="Model.forward"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.forward">[docs]</a> + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">])</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Forward pass."""</span></div> + + + <span class="nd">@staticmethod</span> + <span class="k">def</span> <span class="nf">_construct_trainer</span><span class="p">(</span> + <span class="n">max_epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> + <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">callbacks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Callback</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">ckpt_path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">logger</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LightningLogger</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">log_every_n_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> + <span class="n">gradient_clip_val</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"ddp"</span><span class="p">,</span> + <span class="o">**</span><span class="n">trainer_kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Trainer</span><span class="p">:</span> + + <span class="k">if</span> <span class="n">gpus</span><span class="p">:</span> + <span class="n">accelerator</span> <span class="o">=</span> <span class="s2">"gpu"</span> + <span class="n">devices</span> <span class="o">=</span> <span class="n">gpus</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">accelerator</span> <span class="o">=</span> <span class="s2">"cpu"</span> + <span class="n">devices</span> <span class="o">=</span> <span class="mi">1</span> + + <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span> + <span class="n">accelerator</span><span class="o">=</span><span class="n">accelerator</span><span class="p">,</span> + <span class="n">devices</span><span class="o">=</span><span class="n">devices</span><span class="p">,</span> + <span class="n">max_epochs</span><span class="o">=</span><span class="n">max_epochs</span><span class="p">,</span> + <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span><span class="p">,</span> + <span class="n">log_every_n_steps</span><span class="o">=</span><span class="n">log_every_n_steps</span><span class="p">,</span> + <span class="n">logger</span><span class="o">=</span><span class="n">logger</span><span class="p">,</span> + <span class="n">gradient_clip_val</span><span class="o">=</span><span class="n">gradient_clip_val</span><span class="p">,</span> + <span class="n">strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span> + <span class="n">default_root_dir</span><span class="o">=</span><span class="n">ckpt_path</span><span class="p">,</span> + <span class="o">**</span><span class="n">trainer_kwargs</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">trainer</span> + +<div class="viewcode-block" id="Model.fit"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.fit">[docs]</a> + <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">train_dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> + <span class="n">val_dataloader</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">DataLoader</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">max_epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> + <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">callbacks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Callback</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">ckpt_path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">logger</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LightningLogger</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">log_every_n_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> + <span class="n">gradient_clip_val</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"ddp"</span><span class="p">,</span> + <span class="o">**</span><span class="n">trainer_kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Fit `Model` using `pytorch_lightning.Trainer`."""</span> + <span class="c1"># Checks</span> + <span class="k">if</span> <span class="n">callbacks</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">callbacks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_default_callbacks</span><span class="p">(</span> + <span class="n">val_dataloader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">elif</span> <span class="n">val_dataloader</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">callbacks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_early_stopping</span><span class="p">(</span> + <span class="n">val_dataloader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span> + <span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="n">trainer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_trainer</span><span class="p">(</span> + <span class="n">max_epochs</span><span class="o">=</span><span class="n">max_epochs</span><span class="p">,</span> + <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span> + <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span><span class="p">,</span> + <span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_path</span><span class="p">,</span> + <span class="n">logger</span><span class="o">=</span><span class="n">logger</span><span class="p">,</span> + <span class="n">log_every_n_steps</span><span class="o">=</span><span class="n">log_every_n_steps</span><span class="p">,</span> + <span class="n">gradient_clip_val</span><span class="o">=</span><span class="n">gradient_clip_val</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span> + <span class="o">**</span><span class="n">trainer_kwargs</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">try</span><span class="p">:</span> + <span class="n">trainer</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">train_dataloader</span><span class="p">,</span> <span class="n">val_dataloader</span><span class="p">,</span> <span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_path</span> + <span class="p">)</span> + <span class="k">except</span> <span class="ne">KeyboardInterrupt</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"[ctrl+c] Exiting gracefully."</span><span class="p">)</span> + <span class="k">pass</span></div> + + + <span class="k">def</span> <span class="nf">_create_default_callbacks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val_dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">:</span> + <span class="n">callbacks</span> <span class="o">=</span> <span class="p">[</span><span class="n">ProgressBar</span><span class="p">()]</span> + <span class="n">callbacks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_add_early_stopping</span><span class="p">(</span> + <span class="n">val_dataloader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">callbacks</span> + + <span class="k">def</span> <span class="nf">_add_early_stopping</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">val_dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">callbacks</span><span class="p">:</span> <span class="n">List</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">:</span> + <span class="k">if</span> <span class="n">val_dataloader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">return</span> <span class="n">callbacks</span> + <span class="n">has_early_stopping</span> <span class="o">=</span> <span class="kc">False</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">callbacks</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="k">for</span> <span class="n">callback</span> <span class="ow">in</span> <span class="n">callbacks</span><span class="p">:</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">callback</span><span class="p">,</span> <span class="n">EarlyStopping</span><span class="p">):</span> + <span class="n">has_early_stopping</span> <span class="o">=</span> <span class="kc">True</span> + + <span class="k">if</span> <span class="ow">not</span> <span class="n">has_early_stopping</span><span class="p">:</span> + <span class="n">callbacks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> + <span class="n">EarlyStopping</span><span class="p">(</span> + <span class="n">monitor</span><span class="o">=</span><span class="s2">"val_loss"</span><span class="p">,</span> + <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> + <span class="p">)</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning_once</span><span class="p">(</span> + <span class="s2">"Got validation dataloader but no EarlyStopping callback. An "</span> + <span class="s2">"EarlyStopping callback has been added automatically with "</span> + <span class="s2">"patience=5 and monitor = 'val_loss'."</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">callbacks</span> + +<div class="viewcode-block" id="Model.predict"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.predict">[docs]</a> + <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> + <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return predictions for `dataloader`.</span> + +<span class="sd"> Returns a list of Tensors, one for each model output.</span> +<span class="sd"> """</span> + <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> + + <span class="n">callbacks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_default_callbacks</span><span class="p">(</span> + <span class="n">val_dataloader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> + + <span class="n">inference_trainer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_construct_trainer</span><span class="p">(</span> + <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span> + <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span><span class="p">,</span> + <span class="p">)</span> + + <span class="n">predictions_list</span> <span class="o">=</span> <span class="n">inference_trainer</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">),</span> <span class="s2">"Got no predictions"</span> + + <span class="n">nb_outputs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> + <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span> + <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">preds</span><span class="p">[</span><span class="n">ix</span><span class="p">]</span> <span class="k">for</span> <span class="n">preds</span> <span class="ow">in</span> <span class="n">predictions_list</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="k">for</span> <span class="n">ix</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nb_outputs</span><span class="p">)</span> + <span class="p">]</span> + + <span class="k">return</span> <span class="n">predictions</span></div> + + +<div class="viewcode-block" id="Model.predict_as_dataframe"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.predict_as_dataframe">[docs]</a> + <span class="k">def</span> <span class="nf">predict_as_dataframe</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> + <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return predictions for `dataloader` as a DataFrame.</span> + +<span class="sd"> Include `additional_attributes` as additional columns in the output</span> +<span class="sd"> DataFrame.</span> +<span class="sd"> """</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="n">additional_attributes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">additional_attributes</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">additional_attributes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + + <span class="k">if</span> <span class="p">(</span> + <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">sampler</span><span class="p">,</span> <span class="n">SequentialSampler</span><span class="p">)</span> + <span class="ow">and</span> <span class="n">additional_attributes</span> + <span class="p">):</span> + <span class="nb">print</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">sampler</span><span class="p">)</span> + <span class="k">raise</span> <span class="ne">UserWarning</span><span class="p">(</span> + <span class="s2">"DataLoader has a `sampler` that is not `SequentialSampler`, "</span> + <span class="s2">"indicating that shuffling is enabled. Using "</span> + <span class="s2">"`predict_as_dataframe` with `additional_attributes` assumes "</span> + <span class="s2">"that the sequence of batches in `dataloader` are "</span> + <span class="s2">"deterministic. Either call this method a `dataloader` which "</span> + <span class="s2">"doesn't resample batches; or do not request "</span> + <span class="s2">"`additional_attributes`."</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Column names for predictions are: </span><span class="se">\n</span><span class="s2"> </span><span class="si">{</span><span class="n">prediction_columns</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> + <span class="n">predictions_torch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span> + <span class="n">dataloader</span><span class="o">=</span><span class="n">dataloader</span><span class="p">,</span> + <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span> + <span class="p">)</span> + <span class="n">predictions</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">predictions_torch</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> + <span class="p">)</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">)</span> <span class="o">==</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">(</span> + <span class="sa">f</span><span class="s2">"Number of provided column names (</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">)</span><span class="si">}</span><span class="s2">) and "</span> + <span class="sa">f</span><span class="s2">"number of output columns (</span><span class="si">{</span><span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">) don't match."</span> + <span class="p">)</span> + + <span class="c1"># Get additional attributes</span> + <span class="n">attributes</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]]</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">(</span> + <span class="p">[(</span><span class="n">attr</span><span class="p">,</span> <span class="p">[])</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">additional_attributes</span><span class="p">]</span> + <span class="p">)</span> + + <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span> + <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">attributes</span><span class="p">:</span> + <span class="n">attribute</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">attr</span><span class="p">]</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">attribute</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> + <span class="n">attribute</span> <span class="o">=</span> <span class="n">attribute</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> + + <span class="c1"># Check if node level predictions</span> + <span class="c1"># If true, additional attributes are repeated</span> + <span class="c1"># to make dimensions fit</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">dataset</span><span class="p">):</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">attribute</span><span class="p">)</span> <span class="o"><</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span> + <span class="n">batch</span><span class="o">.</span><span class="n">n_pulses</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> + <span class="p">):</span> + <span class="n">attribute</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span> + <span class="n">attribute</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">n_pulses</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> + <span class="p">)</span> + <span class="k">try</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">attribute</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">x</span><span class="p">)</span> + <span class="k">except</span> <span class="ne">AssertionError</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning_once</span><span class="p">(</span> + <span class="s2">"Could not automatically adjust length"</span> + <span class="sa">f</span><span class="s2">"of additional attribute </span><span class="si">{</span><span class="n">attr</span><span class="si">}</span><span class="s2"> to match length of"</span> + <span class="sa">f</span><span class="s2">"predictions. Make sure </span><span class="si">{</span><span class="n">attr</span><span class="si">}</span><span class="s2"> is a graph-level or"</span> + <span class="s2">"node-level attribute. Attribute skipped."</span> + <span class="p">)</span> + <span class="k">pass</span> + <span class="n">attributes</span><span class="p">[</span><span class="n">attr</span><span class="p">]</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">attribute</span><span class="p">)</span> + + <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> + <span class="p">[</span><span class="n">predictions</span><span class="p">]</span> + <span class="o">+</span> <span class="p">[</span> + <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">values</span><span class="p">)[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> + <span class="k">for</span> <span class="n">values</span> <span class="ow">in</span> <span class="n">attributes</span><span class="o">.</span><span class="n">values</span><span class="p">()</span> + <span class="p">],</span> + <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="n">results</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span> + <span class="n">data</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">prediction_columns</span> <span class="o">+</span> <span class="n">additional_attributes</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">results</span></div> + + +<div class="viewcode-block" id="Model.save"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.save">[docs]</a> + <span class="k">def</span> <span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Save entire model to `path`."""</span> + <span class="k">if</span> <span class="ow">not</span> <span class="n">path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".pth"</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span> + <span class="s2">"It is recommended to use the .pth suffix for model files."</span> + <span class="p">)</span> + <span class="n">dirname</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> + <span class="k">if</span> <span class="n">dirname</span><span class="p">:</span> + <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">dirname</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span> <span class="n">path</span><span class="p">,</span> <span class="n">pickle_module</span><span class="o">=</span><span class="n">dill</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Model saved to </span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="Model.load"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.load">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"Model"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Load entire model from `path`."""</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">pickle_module</span><span class="o">=</span><span class="n">dill</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="Model.save_state_dict"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.save_state_dict">[docs]</a> + <span class="k">def</span> <span class="nf">save_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Save model `state_dict` to `path`."""</span> + <span class="k">if</span> <span class="ow">not</span> <span class="n">path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".pth"</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span> + <span class="s2">"It is recommended to use the .pth suffix for state_dict files."</span> + <span class="p">)</span> + <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">path</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Model state_dict saved to </span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="Model.load_state_dict"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.load_state_dict">[docs]</a> + <span class="k">def</span> <span class="nf">load_state_dict</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Dict</span><span class="p">],</span> <span class="o">**</span><span class="n">kargs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Any</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="s2">"Model"</span><span class="p">:</span> <span class="c1"># pylint: disable=arguments-differ</span> +<span class="w"> </span><span class="sd">"""Load model `state_dict` from `path`."""</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">state_dict</span> <span class="o">=</span> <span class="n">path</span> + <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="o">**</span><span class="n">kargs</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="Model.from_config"> +<a class="viewcode-back" href="../../../api/graphnet.models.model.html#graphnet.models.model.Model.from_config">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span> <span class="c1"># type: ignore[override]</span> + <span class="bp">cls</span><span class="p">,</span> + <span class="n">source</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span> + <span class="n">trust</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="n">load_modules</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="s2">"Model"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `Model` instance from `source` configuration.</span> + +<span class="sd"> Arguments:</span> +<span class="sd"> trust: Whether to trust the ModelConfig file enough to `eval(...)`</span> +<span class="sd"> any lambda function expressions contained.</span> +<span class="sd"> load_modules: List of modules used in the definition of the model</span> +<span class="sd"> which, as a consequence, need to be loaded into the global</span> +<span class="sd"> namespace. Defaults to loading `torch`.</span> + +<span class="sd"> Raises:</span> +<span class="sd"> ValueError: If the ModelConfig contains lambda functions but</span> +<span class="sd"> `trust = False`.</span> +<span class="sd"> """</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">source</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">source</span> <span class="o">=</span> <span class="n">ModelConfig</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">source</span><span class="p">)</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span> + <span class="n">source</span><span class="p">,</span> <span class="n">ModelConfig</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"Argument `source` of type (</span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">source</span><span class="p">)</span><span class="si">}</span><span class="s2">) is not a `ModelConfig"</span> + + <span class="k">return</span> <span class="n">source</span><span class="o">.</span><span class="n">_construct_model</span><span class="p">(</span><span class="n">trust</span><span class="p">,</span> <span class="n">load_modules</span><span class="p">)</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/standard_model.html b/_modules/graphnet/models/standard_model.html new file mode 100644 index 000000000..cb57cb523 --- /dev/null +++ b/_modules/graphnet/models/standard_model.html @@ -0,0 +1,604 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.standard_model — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/standard_model" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.standard_model </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-standard-model--page-root">Source code for graphnet.models.standard_model</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Standard model class(es)."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">ModuleList</span> +<span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Adam</span> +<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> +<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span> +<span class="kn">from</span> <span class="nn">graphnet.models.gnn.gnn</span> <span class="kn">import</span> <span class="n">GNN</span> +<span class="kn">from</span> <span class="nn">graphnet.models.model</span> <span class="kn">import</span> <span class="n">Model</span> +<span class="kn">from</span> <span class="nn">graphnet.models.task</span> <span class="kn">import</span> <span class="n">Task</span> + + +<div class="viewcode-block" id="StandardModel"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel">[docs]</a> +<span class="k">class</span> <span class="nc">StandardModel</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Main class for standard models in graphnet.</span> + +<span class="sd"> This class chains together the different elements of a complete GNN-based</span> +<span class="sd"> model (detector read-in, GNN architecture, and task-specific read-outs).</span> +<span class="sd"> """</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="p">:</span> <span class="n">GraphDefinition</span><span class="p">,</span> + <span class="n">gnn</span><span class="p">:</span> <span class="n">GNN</span><span class="p">,</span> + <span class="n">tasks</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Task</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">Task</span><span class="p">]],</span> + <span class="n">optimizer_class</span><span class="p">:</span> <span class="nb">type</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">,</span> + <span class="n">optimizer_kwargs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">scheduler_class</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">type</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">scheduler_kwargs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">scheduler_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `StandardModel`."""</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="n">Task</span><span class="p">):</span> + <span class="n">tasks</span> <span class="o">=</span> <span class="p">[</span><span class="n">tasks</span><span class="p">]</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">task</span><span class="p">,</span> <span class="n">Task</span><span class="p">)</span> <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="n">tasks</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">graph_definition</span><span class="p">,</span> <span class="n">GraphDefinition</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">gnn</span><span class="p">,</span> <span class="n">GNN</span><span class="p">)</span> + + <span class="c1"># Member variable(s)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span> <span class="o">=</span> <span class="n">graph_definition</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_gnn</span> <span class="o">=</span> <span class="n">gnn</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span> <span class="o">=</span> <span class="n">ModuleList</span><span class="p">(</span><span class="n">tasks</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_optimizer_class</span> <span class="o">=</span> <span class="n">optimizer_class</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_optimizer_kwargs</span> <span class="o">=</span> <span class="n">optimizer_kwargs</span> <span class="ow">or</span> <span class="nb">dict</span><span class="p">()</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_class</span> <span class="o">=</span> <span class="n">scheduler_class</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_kwargs</span> <span class="o">=</span> <span class="n">scheduler_kwargs</span> <span class="ow">or</span> <span class="nb">dict</span><span class="p">()</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_config</span> <span class="o">=</span> <span class="n">scheduler_config</span> <span class="ow">or</span> <span class="nb">dict</span><span class="p">()</span> + + <span class="c1"># set dtype of GNN from graph_definition</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_gnn</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_graph_definition</span><span class="o">.</span><span class="n">_dtype</span><span class="p">)</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">target_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return target label."""</span> + <span class="k">return</span> <span class="p">[</span><span class="n">label</span> <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span> <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">task</span><span class="o">.</span><span class="n">_target_labels</span><span class="p">]</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">prediction_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return prediction labels."""</span> + <span class="k">return</span> <span class="p">[</span> + <span class="n">label</span> <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span> <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">task</span><span class="o">.</span><span class="n">_prediction_labels</span> + <span class="p">]</span> + +<div class="viewcode-block" id="StandardModel.configure_optimizers"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.configure_optimizers">[docs]</a> + <span class="k">def</span> <span class="nf">configure_optimizers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Configure the model's optimizer(s)."""</span> + <span class="n">optimizer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_optimizer_class</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">_optimizer_kwargs</span> + <span class="p">)</span> + <span class="n">config</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"optimizer"</span><span class="p">:</span> <span class="n">optimizer</span><span class="p">,</span> + <span class="p">}</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_class</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">scheduler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_class</span><span class="p">(</span> + <span class="n">optimizer</span><span class="p">,</span> <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_kwargs</span> + <span class="p">)</span> + <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> + <span class="p">{</span> + <span class="s2">"lr_scheduler"</span><span class="p">:</span> <span class="p">{</span> + <span class="s2">"scheduler"</span><span class="p">:</span> <span class="n">scheduler</span><span class="p">,</span> + <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">_scheduler_config</span><span class="p">,</span> + <span class="p">},</span> + <span class="p">}</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">config</span></div> + + +<div class="viewcode-block" id="StandardModel.forward"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.forward">[docs]</a> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Forward pass, chaining model components."""</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">Data</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_gnn</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> + <span class="n">preds</span> <span class="o">=</span> <span class="p">[</span><span class="n">task</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span><span class="p">]</span> + <span class="k">return</span> <span class="n">preds</span></div> + + +<div class="viewcode-block" id="StandardModel.shared_step"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.shared_step">[docs]</a> + <span class="k">def</span> <span class="nf">shared_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform shared step.</span> + +<span class="sd"> Applies the forward pass and the following loss calculation, shared</span> +<span class="sd"> between the training and validation step.</span> +<span class="sd"> """</span> + <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span> + <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> + <span class="k">return</span> <span class="n">loss</span></div> + + +<div class="viewcode-block" id="StandardModel.training_step"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.training_step">[docs]</a> + <span class="k">def</span> <span class="nf">training_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">train_batch</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform training step."""</span> + <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_step</span><span class="p">(</span><span class="n">train_batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span> + <span class="s2">"train_loss"</span><span class="p">,</span> + <span class="n">loss</span><span class="p">,</span> + <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_batch_size</span><span class="p">(</span><span class="n">train_batch</span><span class="p">),</span> + <span class="n">prog_bar</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> + <span class="n">on_epoch</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> + <span class="n">on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> + <span class="n">sync_dist</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">loss</span></div> + + +<div class="viewcode-block" id="StandardModel.validation_step"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.validation_step">[docs]</a> + <span class="k">def</span> <span class="nf">validation_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val_batch</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Perform validation step."""</span> + <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_step</span><span class="p">(</span><span class="n">val_batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span> + <span class="s2">"val_loss"</span><span class="p">,</span> + <span class="n">loss</span><span class="p">,</span> + <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_batch_size</span><span class="p">(</span><span class="n">val_batch</span><span class="p">),</span> + <span class="n">prog_bar</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> + <span class="n">on_epoch</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> + <span class="n">on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> + <span class="n">sync_dist</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">loss</span></div> + + +<div class="viewcode-block" id="StandardModel.compute_loss"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.compute_loss">[docs]</a> + <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">,</span> <span class="n">verbose</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Compute and sum losses across tasks."""</span> + <span class="n">losses</span> <span class="o">=</span> <span class="p">[</span> + <span class="n">task</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span> + <span class="k">for</span> <span class="n">task</span><span class="p">,</span> <span class="n">pred</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span><span class="p">,</span> <span class="n">preds</span><span class="p">)</span> + <span class="p">]</span> + <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">losses</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> + <span class="k">assert</span> <span class="nb">all</span><span class="p">(</span> + <span class="n">loss</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">loss</span> <span class="ow">in</span> <span class="n">losses</span> + <span class="p">),</span> <span class="s2">"Please reduce loss for each task separately"</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">))</span></div> + + + <span class="k">def</span> <span class="nf">_get_batch_size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">numel</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">batch</span><span class="p">))</span> + +<div class="viewcode-block" id="StandardModel.inference"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.inference">[docs]</a> + <span class="k">def</span> <span class="nf">inference</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Activate inference mode."""</span> + <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span><span class="p">:</span> + <span class="n">task</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span></div> + + +<div class="viewcode-block" id="StandardModel.train"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.train">[docs]</a> + <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"Model"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Deactivate inference mode."""</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">mode</span><span class="p">)</span> + <span class="k">if</span> <span class="n">mode</span><span class="p">:</span> + <span class="k">for</span> <span class="n">task</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tasks</span><span class="p">:</span> + <span class="n">task</span><span class="o">.</span><span class="n">train_eval</span><span class="p">()</span> + <span class="k">return</span> <span class="bp">self</span></div> + + +<div class="viewcode-block" id="StandardModel.predict"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.predict">[docs]</a> + <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> + <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return predictions for `dataloader`."""</span> + <span class="bp">self</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span> + <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span> + <span class="n">dataloader</span><span class="o">=</span><span class="n">dataloader</span><span class="p">,</span> + <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span> + <span class="p">)</span></div> + + +<div class="viewcode-block" id="StandardModel.predict_as_dataframe"> +<a class="viewcode-back" href="../../../api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.predict_as_dataframe">[docs]</a> + <span class="k">def</span> <span class="nf">predict_as_dataframe</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> + <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">gpus</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return predictions for `dataloader` as a DataFrame.</span> + +<span class="sd"> Include `additional_attributes` as additional columns in the output</span> +<span class="sd"> DataFrame.</span> +<span class="sd"> """</span> + <span class="k">if</span> <span class="n">prediction_columns</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">prediction_columns</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_labels</span> + <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">predict_as_dataframe</span><span class="p">(</span> + <span class="n">dataloader</span><span class="o">=</span><span class="n">dataloader</span><span class="p">,</span> + <span class="n">prediction_columns</span><span class="o">=</span><span class="n">prediction_columns</span><span class="p">,</span> + <span class="n">additional_attributes</span><span class="o">=</span><span class="n">additional_attributes</span><span class="p">,</span> + <span class="n">gpus</span><span class="o">=</span><span class="n">gpus</span><span class="p">,</span> + <span class="n">distribution_strategy</span><span class="o">=</span><span class="n">distribution_strategy</span><span class="p">,</span> + <span class="p">)</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/task/classification.html b/_modules/graphnet/models/task/classification.html new file mode 100644 index 000000000..b3c6c9440 --- /dev/null +++ b/_modules/graphnet/models/task/classification.html @@ -0,0 +1,411 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.task.classification — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/task/classification" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.task.classification </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-task-classification--page-root">Source code for graphnet.models.task.classification</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Classification-specific `Model` class(es)."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> + +<span class="kn">from</span> <span class="nn">graphnet.models.task</span> <span class="kn">import</span> <span class="n">Task</span><span class="p">,</span> <span class="n">IdentityTask</span> + + +<div class="viewcode-block" id="MulticlassClassificationTask"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.classification.html#graphnet.models.task.classification.MulticlassClassificationTask">[docs]</a> +<span class="k">class</span> <span class="nc">MulticlassClassificationTask</span><span class="p">(</span><span class="n">IdentityTask</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""General task for classifying any number of classes.</span> + +<span class="sd"> Requires the same number of input features as the number of classes being</span> +<span class="sd"> predicted. Returns the untransformed latent features, which are interpreted</span> +<span class="sd"> as the logits for each class being classified.</span> +<span class="sd"> """</span></div> + + + +<div class="viewcode-block" id="BinaryClassificationTask"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask">[docs]</a> +<span class="k">class</span> <span class="nc">BinaryClassificationTask</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Performs binary classification."""</span> + + <span class="c1"># Requires one feature, logit for being signal class.</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"target"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"target_pred"</span><span class="p">]</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># transform probability of being muon</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="BinaryClassificationTaskLogits"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits">[docs]</a> +<span class="k">class</span> <span class="nc">BinaryClassificationTaskLogits</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Performs binary classification form logits."""</span> + + <span class="c1"># Requires one feature, logit for being signal class.</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"target"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"target_pred"</span><span class="p">]</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">x</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/task/reconstruction.html b/_modules/graphnet/models/task/reconstruction.html new file mode 100644 index 000000000..ccb644a97 --- /dev/null +++ b/_modules/graphnet/models/task/reconstruction.html @@ -0,0 +1,609 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.task.reconstruction — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/task/reconstruction" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.task.reconstruction </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-task-reconstruction--page-root">Source code for graphnet.models.task.reconstruction</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Reconstruction-specific `Model` class(es)."""</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> + +<span class="kn">from</span> <span class="nn">graphnet.models.task</span> <span class="kn">import</span> <span class="n">Task</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.maths</span> <span class="kn">import</span> <span class="n">eps_like</span> + + +<div class="viewcode-block" id="AzimuthReconstructionWithKappa"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa">[docs]</a> +<span class="k">class</span> <span class="nc">AzimuthReconstructionWithKappa</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs azimuthal angle and associated kappa (1/var)."""</span> + + <span class="c1"># Requires two features: untransformed points in (x,y)-space.</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"azimuth"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"azimuth_pred"</span><span class="p">,</span> <span class="s2">"azimuth_kappa"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">2</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform outputs to angle and prepare prediction</span> + <span class="n">kappa</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">vector_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">angle</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">atan2</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> + <span class="n">angle</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span> + <span class="n">angle</span> <span class="o"><</span> <span class="mi">0</span><span class="p">,</span> <span class="n">angle</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span> <span class="n">angle</span> + <span class="p">)</span> <span class="c1"># atan(y,x) -> [-pi, pi]</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">angle</span><span class="p">,</span> <span class="n">kappa</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="AzimuthReconstruction"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction">[docs]</a> +<span class="k">class</span> <span class="nc">AzimuthReconstruction</span><span class="p">(</span><span class="n">AzimuthReconstructionWithKappa</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs azimuthal angle."""</span> + + <span class="c1"># Requires two features: untransformed points in (x,y)-space.</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"azimuth"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"azimuth_pred"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">2</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform outputs to angle and prepare prediction</span> + <span class="n">res</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">angle</span> <span class="o">=</span> <span class="n">res</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> + <span class="n">kappa</span> <span class="o">=</span> <span class="n">res</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> + <span class="n">sigma</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">kappa</span><span class="p">)</span> + <span class="n">beta</span> <span class="o">=</span> <span class="mf">1e-3</span> + <span class="n">kl_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">sigma</span><span class="o">**</span><span class="mi">2</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">sigma</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_regularisation_loss</span> <span class="o">+=</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">kl_loss</span> + <span class="k">return</span> <span class="n">angle</span></div> + + + +<div class="viewcode-block" id="DirectionReconstructionWithKappa"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa">[docs]</a> +<span class="k">class</span> <span class="nc">DirectionReconstructionWithKappa</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs direction with kappa from the 3D-vMF distribution."""</span> + + <span class="c1"># Requires three features: untransformed points in (x,y,z)-space.</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span> + <span class="s2">"direction"</span> + <span class="p">]</span> <span class="c1"># contains dir_x, dir_y, dir_z see https://github.com/graphnet-team/graphnet/blob/95309556cfd46a4046bc4bd7609888aab649e295/src/graphnet/training/labels.py#L29</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span> + <span class="s2">"dir_x_pred"</span><span class="p">,</span> + <span class="s2">"dir_y_pred"</span><span class="p">,</span> + <span class="s2">"dir_z_pred"</span><span class="p">,</span> + <span class="s2">"direction_kappa"</span><span class="p">,</span> + <span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">3</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform outputs to angle and prepare prediction</span> + <span class="n">kappa</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">vector_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">vec_x</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">kappa</span> + <span class="n">vec_y</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">kappa</span> + <span class="n">vec_z</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="n">kappa</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">vec_x</span><span class="p">,</span> <span class="n">vec_y</span><span class="p">,</span> <span class="n">vec_z</span><span class="p">,</span> <span class="n">kappa</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="ZenithReconstruction"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction">[docs]</a> +<span class="k">class</span> <span class="nc">ZenithReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs zenith angle."""</span> + + <span class="c1"># Requires two features: zenith angle itself.</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"zenith"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"zenith_pred"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform outputs to angle and prepare prediction</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">1</span><span class="p">])</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span></div> + + + +<div class="viewcode-block" id="ZenithReconstructionWithKappa"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa">[docs]</a> +<span class="k">class</span> <span class="nc">ZenithReconstructionWithKappa</span><span class="p">(</span><span class="n">ZenithReconstruction</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs zenith angle and associated kappa (1/var)."""</span> + + <span class="c1"># Requires one feature in addition to `ZenithReconstruction`: kappa (unceratinty; 1/variance).</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"zenith"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"zenith_pred"</span><span class="p">,</span> <span class="s2">"zenith_kappa"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">2</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform outputs to angle and prepare prediction</span> + <span class="n">angle</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> + <span class="n">kappa</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span> <span class="o">+</span> <span class="n">eps_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">angle</span><span class="p">,</span> <span class="n">kappa</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="EnergyReconstruction"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction">[docs]</a> +<span class="k">class</span> <span class="nc">EnergyReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs energy using stable method."""</span> + + <span class="c1"># Requires one feature: untransformed energy</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy_pred"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform to positive energy domain avoiding `-inf` in `log10`</span> + <span class="c1"># Transform, thereby preventing overflow and underflow error.</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="mf">0.05</span><span class="p">)</span> <span class="o">+</span> <span class="n">eps_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="EnergyReconstructionWithPower"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower">[docs]</a> +<span class="k">class</span> <span class="nc">EnergyReconstructionWithPower</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs energy."""</span> + + <span class="c1"># Requires one feature: untransformed energy</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy_pred"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform energy</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="EnergyReconstructionWithUncertainty"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty">[docs]</a> +<span class="k">class</span> <span class="nc">EnergyReconstructionWithUncertainty</span><span class="p">(</span><span class="n">EnergyReconstruction</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs energy and associated uncertainty (log(var))."""</span> + + <span class="c1"># Requires one feature in addition to `EnergyReconstruction`: log-variance (uncertainty).</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"energy_pred"</span><span class="p">,</span> <span class="s2">"energy_sigma"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">2</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform energy</span> + <span class="n">energy</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> + <span class="n">log_var</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> + <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">energy</span><span class="p">,</span> <span class="n">log_var</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + <span class="k">return</span> <span class="n">pred</span></div> + + + +<div class="viewcode-block" id="VertexReconstruction"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction">[docs]</a> +<span class="k">class</span> <span class="nc">VertexReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs vertex position and time."""</span> + + <span class="c1"># Requires four features, x, y, z, and t.</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"vertex"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span> + <span class="s2">"position_x_pred"</span><span class="p">,</span> + <span class="s2">"position_y_pred"</span><span class="p">,</span> + <span class="s2">"position_z_pred"</span><span class="p">,</span> + <span class="s2">"interaction_time_pred"</span><span class="p">,</span> + <span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">4</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Scale xyz to roughly the right order of magnitude, leave time</span> + <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span> + <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span> + <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span> + + <span class="k">return</span> <span class="n">x</span></div> + + + +<div class="viewcode-block" id="PositionReconstruction"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction">[docs]</a> +<span class="k">class</span> <span class="nc">PositionReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs vertex position."""</span> + + <span class="c1"># Requires three features, x, y, and z.</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"position"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span> + <span class="s2">"position_x_pred"</span><span class="p">,</span> + <span class="s2">"position_y_pred"</span><span class="p">,</span> + <span class="s2">"position_z_pred"</span><span class="p">,</span> + <span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">3</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Scale to roughly the right order of magnitude</span> + <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span> + <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span> + <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1e2</span> + + <span class="k">return</span> <span class="n">x</span></div> + + + +<div class="viewcode-block" id="TimeReconstruction"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction">[docs]</a> +<span class="k">class</span> <span class="nc">TimeReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs time."""</span> + + <span class="c1"># Requires one feature, time.</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"interaction_time"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"interaction_time_pred"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Leave as it is</span> + <span class="k">return</span> <span class="n">x</span></div> + + + +<div class="viewcode-block" id="InelasticityReconstruction"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction">[docs]</a> +<span class="k">class</span> <span class="nc">InelasticityReconstruction</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Reconstructs interaction inelasticity.</span> + +<span class="sd"> That is, 1-(track energy / hadronic energy).</span> +<span class="sd"> """</span> + + <span class="c1"># Requires one features: inelasticity itself</span> + <span class="n">default_target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"inelasticity"</span><span class="p">]</span> + <span class="n">default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"inelasticity_pred"</span><span class="p">]</span> + <span class="n">nb_inputs</span> <span class="o">=</span> <span class="mi">1</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Transform output to unit range</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/task/task.html b/_modules/graphnet/models/task/task.html new file mode 100644 index 000000000..c43cb2f73 --- /dev/null +++ b/_modules/graphnet/models/task/task.html @@ -0,0 +1,688 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.task.task — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/task/task" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.task.task </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-task-task--page-root">Source code for graphnet.models.task.task</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Base physics task-specific `Model` class(es)."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">TYPE_CHECKING</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Optional</span> +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> + +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">Linear</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> + +<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span> + <span class="c1"># Avoid cyclic dependency</span> + <span class="kn">from</span> <span class="nn">graphnet.training.loss_functions</span> <span class="kn">import</span> <span class="n">LossFunction</span> <span class="c1"># type: ignore[attr-defined]</span> + +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span> + + +<div class="viewcode-block" id="Task"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task">[docs]</a> +<span class="k">class</span> <span class="nc">Task</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base class for all reconstruction and classification tasks."""</span> + + <span class="nd">@property</span> + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">nb_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return number of inputs assumed by task."""</span> + + <span class="nd">@property</span> + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">default_target_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return default target labels."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_default_target_labels</span> + + <span class="nd">@property</span> + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">default_prediction_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return default prediction labels."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_default_prediction_labels</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">loss_function</span><span class="p">:</span> <span class="s2">"LossFunction"</span><span class="p">,</span> + <span class="n">target_labels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">prediction_labels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">transform_prediction_and_target</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">transform_target</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">transform_inference</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">transform_support</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `Task`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> hidden_size: The number of nodes in the layer feeding into this</span> +<span class="sd"> tasks, used to construct the affine transformation to the</span> +<span class="sd"> predicted quantity.</span> +<span class="sd"> loss_function: Loss function appropriate to the task.</span> +<span class="sd"> target_labels: Name(s) of the quantity/-ies being predicted, used</span> +<span class="sd"> to extract the target tensor(s) from the `Data` object in</span> +<span class="sd"> `.compute_loss(...)`.</span> +<span class="sd"> prediction_labels: The name(s) of each column that is predicted by</span> +<span class="sd"> the model during inference. If not given, the name will auto</span> +<span class="sd"> matically be set to `target_label + _pred`.</span> +<span class="sd"> transform_prediction_and_target: Optional function to transform</span> +<span class="sd"> both the predicted and target tensor before passing them to the</span> +<span class="sd"> loss function. Useful e.g. for having the model predict</span> +<span class="sd"> quantities on a physical scale, but transforming this scale to</span> +<span class="sd"> O(1) for a numerically stable loss computation.</span> +<span class="sd"> transform_target: Optional function to transform only the target</span> +<span class="sd"> tensor before passing it, and the predicted tensor, to the loss</span> +<span class="sd"> function. Useful e.g. for having the model predict a</span> +<span class="sd"> transformed version of the target quantity, e.g. the log10-</span> +<span class="sd"> scaled energy, rather than the physical quantity itself. Used</span> +<span class="sd"> in conjunction with `transform_inference` to perform the</span> +<span class="sd"> inverse transform on the predicted quantity to recover the</span> +<span class="sd"> physical scale.</span> +<span class="sd"> transform_inference: Optional function to inverse-transform the</span> +<span class="sd"> model prediction to recover a physical scale. Used in</span> +<span class="sd"> conjunction with `transform_target`.</span> +<span class="sd"> transform_support: Optional tuple to specify minimum and maximum</span> +<span class="sd"> of the range of validity for the inverse transforms</span> +<span class="sd"> `transform_target` and `transform_inference` in case this is</span> +<span class="sd"> restricted. By default the invertibility of `transform_target`</span> +<span class="sd"> is tested on the range [-1e6, 1e6].</span> +<span class="sd"> loss_weight: Name of the attribute in `data` containing per-event</span> +<span class="sd"> loss weights.</span> +<span class="sd"> """</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="n">target_labels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">target_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_target_labels</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">target_labels</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">target_labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">target_labels</span><span class="p">]</span> + + <span class="k">if</span> <span class="n">prediction_labels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">prediction_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_prediction_labels</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">prediction_labels</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">prediction_labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">prediction_labels</span><span class="p">]</span> + + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">target_labels</span><span class="p">,</span> <span class="n">List</span><span class="p">)</span> <span class="c1"># mypy</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">prediction_labels</span><span class="p">,</span> <span class="n">List</span><span class="p">)</span> <span class="c1"># mypy</span> + <span class="c1"># Member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_regularisation_loss</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_target_labels</span> <span class="o">=</span> <span class="n">target_labels</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_prediction_labels</span> <span class="o">=</span> <span class="n">prediction_labels</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_function</span> <span class="o">=</span> <span class="n">loss_function</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span> <span class="o">=</span> <span class="kc">False</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight</span> <span class="o">=</span> <span class="n">loss_weight</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_training</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[</span> + <span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span> + <span class="p">]</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_inference</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[</span> + <span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span> + <span class="p">]</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_transform_target</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_validate_and_set_transforms</span><span class="p">(</span> + <span class="n">transform_prediction_and_target</span><span class="p">,</span> + <span class="n">transform_target</span><span class="p">,</span> + <span class="n">transform_inference</span><span class="p">,</span> + <span class="n">transform_support</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Mapping from last hidden layer to required size of input</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_affine</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nb_inputs</span><span class="p">)</span> + +<div class="viewcode-block" id="Task.forward"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task.forward">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">])</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Forward pass."""</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_regularisation_loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Reset</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_affine</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div> + + + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">_transform_prediction</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]:</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span><span class="p">:</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_inference</span><span class="p">(</span><span class="n">prediction</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_training</span><span class="p">(</span><span class="n">prediction</span><span class="p">)</span> + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">])</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Syntax like `.forward`, for implentation in inheriting classes."""</span> + +<div class="viewcode-block" id="Task.compute_loss"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task.compute_loss">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Data</span><span class="p">],</span> <span class="n">data</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Compute loss of `pred` wrt.</span> + +<span class="sd"> target labels in `data`.</span> +<span class="sd"> """</span> + <span class="n">target</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span> + <span class="p">[</span><span class="n">data</span><span class="p">[</span><span class="n">label</span><span class="p">]</span> <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_target_labels</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span> + <span class="p">)</span> + <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform_target</span><span class="p">(</span><span class="n">target</span><span class="p">)</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">weights</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_loss_weight</span><span class="p">]</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">weights</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_loss_function</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">)</span> + <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">_regularisation_loss</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">loss</span></div> + + +<div class="viewcode-block" id="Task.inference"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task.inference">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">inference</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Activate inference mode."""</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span> <span class="o">=</span> <span class="kc">True</span></div> + + +<div class="viewcode-block" id="Task.train_eval"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.Task.train_eval">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">train_eval</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Deactivate inference mode."""</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_inference</span> <span class="o">=</span> <span class="kc">False</span></div> + + + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">_validate_and_set_transforms</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">transform_prediction_and_target</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Callable</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> + <span class="n">transform_target</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Callable</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> + <span class="n">transform_inference</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Callable</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> + <span class="n">transform_support</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tuple</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Validate and set transforms.</span> + +<span class="sd"> Assert that a valid combination of transformation arguments are passed</span> +<span class="sd"> and update the corresponding functions.</span> +<span class="sd"> """</span> + <span class="c1"># Checks</span> + <span class="k">assert</span> <span class="ow">not</span> <span class="p">(</span> + <span class="p">(</span><span class="n">transform_prediction_and_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> + <span class="ow">and</span> <span class="p">(</span><span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> + <span class="p">),</span> <span class="s2">"Please specify at most one of `transform_prediction_and_target` and `transform_target`"</span> + <span class="k">if</span> <span class="p">(</span><span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">!=</span> <span class="p">(</span><span class="n">transform_inference</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span> + <span class="s2">"Setting one of `transform_target` and `transform_inference`, but not "</span> + <span class="s2">"the other."</span> + <span class="p">)</span> + + <span class="k">if</span> <span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">assert</span> <span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + <span class="k">assert</span> <span class="n">transform_inference</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + + <span class="k">if</span> <span class="n">transform_support</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="k">assert</span> <span class="n">transform_support</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + + <span class="k">assert</span> <span class="p">(</span> + <span class="nb">len</span><span class="p">(</span><span class="n">transform_support</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span> + <span class="p">),</span> <span class="s2">"Please specify min and max for transformation support."</span> + <span class="n">x_test</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span> + <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">transform_support</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">transform_support</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">10</span><span class="p">)</span> + <span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">x_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logspace</span><span class="p">(</span><span class="o">-</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">12</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> + <span class="n">x_test</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span> + <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="o">-</span><span class="n">x_test</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">x_test</span><span class="p">])</span> + <span class="p">)</span> + + <span class="c1"># Add feature dimension before inference transformation to make it</span> + <span class="c1"># match the dimensions of a standard prediction. Remove it again</span> + <span class="c1"># before comparison. Temporary</span> + <span class="k">try</span><span class="p">:</span> + <span class="n">t_test</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">transform_target</span><span class="p">(</span><span class="n">x_test</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> + <span class="n">t_test</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">transform_inference</span><span class="p">(</span><span class="n">t_test</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> + <span class="n">valid</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">isfinite</span><span class="p">(</span><span class="n">t_test</span><span class="p">)</span> + + <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">t_test</span><span class="p">[</span><span class="n">valid</span><span class="p">],</span> <span class="n">x_test</span><span class="p">[</span><span class="n">valid</span><span class="p">]),</span> <span class="p">(</span> + <span class="s2">"The provided transforms for targets during training and "</span> + <span class="s2">"predictions during inference are not inverse. Please "</span> + <span class="s2">"adjust transformation functions or support."</span> + <span class="p">)</span> + <span class="k">del</span> <span class="n">x_test</span><span class="p">,</span> <span class="n">t_test</span><span class="p">,</span> <span class="n">valid</span> + + <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span> + <span class="s2">"transform_target and/or transform_inference rely on "</span> + <span class="s2">"indexing, which we won't validate. Please make sure that "</span> + <span class="s2">"they are mutually inverse, i.e. that</span><span class="se">\n</span><span class="s2">"</span> + <span class="s2">" x = transform_inference(transform_target(x))</span><span class="se">\n</span><span class="s2">"</span> + <span class="s2">"for all x that are within your target range."</span> + <span class="p">)</span> + + <span class="c1"># Set transforms</span> + <span class="k">if</span> <span class="n">transform_prediction_and_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_training</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">transform_prediction_and_target</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_transform_target</span> <span class="o">=</span> <span class="n">transform_prediction_and_target</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">if</span> <span class="n">transform_target</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_transform_target</span> <span class="o">=</span> <span class="n">transform_target</span> + <span class="k">if</span> <span class="n">transform_inference</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_transform_prediction_inference</span> <span class="o">=</span> <span class="n">transform_inference</span></div> + + + +<div class="viewcode-block" id="IdentityTask"> +<a class="viewcode-back" href="../../../../api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask">[docs]</a> +<span class="k">class</span> <span class="nc">IdentityTask</span><span class="p">(</span><span class="n">Task</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Identity, or trivial, task."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">nb_outputs</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">target_labels</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">Any</span><span class="p">],</span> + <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct IdentityTask.</span> + +<span class="sd"> Return the `nb_outputs` as a direct, affine transformation of the last</span> +<span class="sd"> hidden layer.</span> +<span class="sd"> """</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> <span class="o">=</span> <span class="n">nb_outputs</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_default_target_labels</span> <span class="o">=</span> <span class="p">(</span> + <span class="n">target_labels</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">target_labels</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + <span class="k">else</span> <span class="p">[</span><span class="n">target_labels</span><span class="p">]</span> + <span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_default_prediction_labels</span> <span class="o">=</span> <span class="p">[</span> + <span class="sa">f</span><span class="s2">"target_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">_pred"</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_default_target_labels</span><span class="p">))</span> + <span class="p">]</span> + + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + <span class="c1"># Base class constructor</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">default_target_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return default target labels."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_default_target_labels</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">default_prediction_labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return default prediction labels."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_default_prediction_labels</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">nb_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return number of inputs assumed by task."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_inputs</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="c1"># Leave it as is.</span> + <span class="k">return</span> <span class="n">x</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/models/utils.html b/_modules/graphnet/models/utils.html new file mode 100644 index 000000000..35c1b3460 --- /dev/null +++ b/_modules/graphnet/models/utils.html @@ -0,0 +1,430 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.models.utils — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/models/utils" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.models.utils </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-models-utils--page-root">Source code for graphnet.models.utils</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Utility functions for `graphnet.models`."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span> +<span class="kn">from</span> <span class="nn">torch_geometric.nn</span> <span class="kn">import</span> <span class="n">knn_graph</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Batch</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">LongTensor</span> + +<span class="kn">from</span> <span class="nn">torch_geometric.utils.homophily</span> <span class="kn">import</span> <span class="n">homophily</span> + + +<div class="viewcode-block" id="calculate_xyzt_homophily"> +<a class="viewcode-back" href="../../../api/graphnet.models.utils.html#graphnet.models.utils.calculate_xyzt_homophily">[docs]</a> +<span class="k">def</span> <span class="nf">calculate_xyzt_homophily</span><span class="p">(</span> + <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">:</span> <span class="n">LongTensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Batch</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Calculate xyzt-homophily from a batch of graphs.</span> + +<span class="sd"> Homophily is a graph scalar quantity that measures the likeness of</span> +<span class="sd"> variables in nodes. Notice that this calculator assumes a special order of</span> +<span class="sd"> input features in x.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Tuple, each element with shape [batch_size,1].</span> +<span class="sd"> """</span> + <span class="n">hx</span> <span class="o">=</span> <span class="n">homophily</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="n">hy</span> <span class="o">=</span> <span class="n">homophily</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="n">hz</span> <span class="o">=</span> <span class="n">homophily</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="n">ht</span> <span class="o">=</span> <span class="n">homophily</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">batch</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="k">return</span> <span class="n">hx</span><span class="p">,</span> <span class="n">hy</span><span class="p">,</span> <span class="n">hz</span><span class="p">,</span> <span class="n">ht</span></div> + + + +<div class="viewcode-block" id="calculate_distance_matrix"> +<a class="viewcode-back" href="../../../api/graphnet.models.utils.html#graphnet.models.utils.calculate_distance_matrix">[docs]</a> +<span class="k">def</span> <span class="nf">calculate_distance_matrix</span><span class="p">(</span><span class="n">xyz_coords</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Calculate the matrix of pairwise distances between pulses.</span> + +<span class="sd"> Args:</span> +<span class="sd"> xyz_coords: (x,y,z)-coordinates of pulses, of shape [nb_doms, 3].</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Matrix of pairwise distances, of shape [nb_doms, nb_doms]</span> +<span class="sd"> """</span> + <span class="n">diff</span> <span class="o">=</span> <span class="n">xyz_coords</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">xyz_coords</span><span class="o">.</span><span class="n">T</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">diff</span><span class="o">**</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span></div> + + + +<div class="viewcode-block" id="knn_graph_batch"> +<a class="viewcode-back" href="../../../api/graphnet.models.utils.html#graphnet.models.utils.knn_graph_batch">[docs]</a> +<span class="k">def</span> <span class="nf">knn_graph_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">:</span> <span class="n">Batch</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Batch</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Calculate k-nearest-neighbours with individual k for each batch event.</span> + +<span class="sd"> Args:</span> +<span class="sd"> batch: Batch of events.</span> +<span class="sd"> k: A list of k's.</span> +<span class="sd"> columns: The columns of Data.x used for computing the distances. E.g.,</span> +<span class="sd"> Data.x[:,[0,1,2]]</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Returns the same batch of events, but with updated edges.</span> +<span class="sd"> """</span> + <span class="n">data_list</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">to_data_list</span><span class="p">()</span> + <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">data_list</span><span class="p">)):</span> + <span class="n">data_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">edge_index</span> <span class="o">=</span> <span class="n">knn_graph</span><span class="p">(</span> + <span class="n">x</span><span class="o">=</span><span class="n">data_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">x</span><span class="p">[:,</span> <span class="n">columns</span><span class="p">],</span> <span class="n">k</span><span class="o">=</span><span class="n">k</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">Batch</span><span class="o">.</span><span class="n">from_data_list</span><span class="p">(</span><span class="n">data_list</span><span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/pisa/fitting.html b/_modules/graphnet/pisa/fitting.html index 8be91cc6e..1865353e4 100644 --- a/_modules/graphnet/pisa/fitting.html +++ b/_modules/graphnet/pisa/fitting.html @@ -1169,7 +1169,7 @@ <h1 id="modules-graphnet-pisa-fitting--page-root">Source code for graphnet.pisa. </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/pisa/plotting.html b/_modules/graphnet/pisa/plotting.html index c1cbdd4a9..6f31c1d90 100644 --- a/_modules/graphnet/pisa/plotting.html +++ b/_modules/graphnet/pisa/plotting.html @@ -528,7 +528,7 @@ <h1 id="modules-graphnet-pisa-plotting--page-root">Source code for graphnet.pisa </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/training/callbacks.html b/_modules/graphnet/training/callbacks.html new file mode 100644 index 000000000..fcf981d92 --- /dev/null +++ b/_modules/graphnet/training/callbacks.html @@ -0,0 +1,544 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.training.callbacks — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/training/callbacks" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.training.callbacks </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-training-callbacks--page-root">Source code for graphnet.training.callbacks</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Callback class(es) for using during model training."""</span> + +<span class="kn">import</span> <span class="nn">logging</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span> +<span class="kn">import</span> <span class="nn">warnings</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">from</span> <span class="nn">tqdm.std</span> <span class="kn">import</span> <span class="n">Bar</span> + +<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">Trainer</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning.callbacks</span> <span class="kn">import</span> <span class="n">TQDMProgressBar</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning.utilities</span> <span class="kn">import</span> <span class="n">rank_zero_only</span> +<span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Optimizer</span> +<span class="kn">from</span> <span class="nn">torch.optim.lr_scheduler</span> <span class="kn">import</span> <span class="n">_LRScheduler</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span> + + +<div class="viewcode-block" id="PiecewiseLinearLR"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR">[docs]</a> +<span class="k">class</span> <span class="nc">PiecewiseLinearLR</span><span class="p">(</span><span class="n">_LRScheduler</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Interpolate learning rate linearly between milestones."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">optimizer</span><span class="p">:</span> <span class="n">Optimizer</span><span class="p">,</span> + <span class="n">milestones</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> + <span class="n">factors</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> + <span class="n">last_epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> + <span class="n">verbose</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `PiecewiseLinearLR`.</span> + +<span class="sd"> For each milestone, denoting a specified number of steps, a factor</span> +<span class="sd"> multiplying the base learning rate is specified. For steps between two</span> +<span class="sd"> milestones, the learning rate is interpolated linearly between the two</span> +<span class="sd"> closest milestones. For steps before the first milestone, the factor</span> +<span class="sd"> for the first milestone is used; vice versa for steps after the last</span> +<span class="sd"> milestone.</span> + +<span class="sd"> Args:</span> +<span class="sd"> optimizer: Wrapped optimizer.</span> +<span class="sd"> milestones: List of step indices. Must be increasing.</span> +<span class="sd"> factors: List of multiplicative factors. Must be same length as</span> +<span class="sd"> `milestones`.</span> +<span class="sd"> last_epoch: The index of the last epoch.</span> +<span class="sd"> verbose: If ``True``, prints a message to stdout for each update.</span> +<span class="sd"> """</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="n">milestones</span> <span class="o">!=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">milestones</span><span class="p">):</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Milestones must be increasing"</span><span class="p">)</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">milestones</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">factors</span><span class="p">):</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> + <span class="s2">"Only multiplicative factor must be specified for each milestone."</span> + <span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">milestones</span> <span class="o">=</span> <span class="n">milestones</span> + <span class="bp">self</span><span class="o">.</span><span class="n">factors</span> <span class="o">=</span> <span class="n">factors</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">last_epoch</span><span class="p">,</span> <span class="n">verbose</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_get_factor</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span> + <span class="c1"># Linearly interpolate multiplicative factor between milestones.</span> + <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">interp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">last_epoch</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">milestones</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">factors</span><span class="p">)</span> + +<div class="viewcode-block" id="PiecewiseLinearLR.get_lr"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR.get_lr">[docs]</a> + <span class="k">def</span> <span class="nf">get_lr</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Get effective learning rate(s) for each optimizer."""</span> + <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_lr_called_within_step</span><span class="p">:</span> + <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span> + <span class="s2">"To get the last learning rate computed by the scheduler, "</span> + <span class="s2">"please use `get_last_lr()`."</span><span class="p">,</span> + <span class="ne">UserWarning</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="p">[</span><span class="n">base_lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_factor</span><span class="p">()</span> <span class="k">for</span> <span class="n">base_lr</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_lrs</span><span class="p">]</span></div> +</div> + + + +<div class="viewcode-block" id="ProgressBar"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar">[docs]</a> +<span class="k">class</span> <span class="nc">ProgressBar</span><span class="p">(</span><span class="n">TQDMProgressBar</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Custom progress bar for graphnet.</span> + +<span class="sd"> Customises the default progress in pytorch-lightning.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="nf">_common_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">bar</span><span class="p">:</span> <span class="n">Bar</span><span class="p">)</span> <span class="o">-></span> <span class="n">Bar</span><span class="p">:</span> + <span class="n">bar</span><span class="o">.</span><span class="n">unit</span> <span class="o">=</span> <span class="s2">" batch(es)"</span> + <span class="n">bar</span><span class="o">.</span><span class="n">colour</span> <span class="o">=</span> <span class="s2">"green"</span> + <span class="k">return</span> <span class="n">bar</span> + +<div class="viewcode-block" id="ProgressBar.init_validation_tqdm"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_validation_tqdm">[docs]</a> + <span class="k">def</span> <span class="nf">init_validation_tqdm</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Bar</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Override for customisation."""</span> + <span class="n">bar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_validation_tqdm</span><span class="p">()</span> + <span class="n">bar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_common_config</span><span class="p">(</span><span class="n">bar</span><span class="p">)</span> + <span class="k">return</span> <span class="n">bar</span></div> + + +<div class="viewcode-block" id="ProgressBar.init_predict_tqdm"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_predict_tqdm">[docs]</a> + <span class="k">def</span> <span class="nf">init_predict_tqdm</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Bar</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Override for customisation."""</span> + <span class="n">bar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_predict_tqdm</span><span class="p">()</span> + <span class="n">bar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_common_config</span><span class="p">(</span><span class="n">bar</span><span class="p">)</span> + <span class="k">return</span> <span class="n">bar</span></div> + + +<div class="viewcode-block" id="ProgressBar.init_test_tqdm"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_test_tqdm">[docs]</a> + <span class="k">def</span> <span class="nf">init_test_tqdm</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Bar</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Override for customisation."""</span> + <span class="n">bar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_test_tqdm</span><span class="p">()</span> + <span class="n">bar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_common_config</span><span class="p">(</span><span class="n">bar</span><span class="p">)</span> + <span class="k">return</span> <span class="n">bar</span></div> + + +<div class="viewcode-block" id="ProgressBar.init_train_tqdm"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_train_tqdm">[docs]</a> + <span class="k">def</span> <span class="nf">init_train_tqdm</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Bar</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Override for customisation."""</span> + <span class="n">bar</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">init_train_tqdm</span><span class="p">()</span> + <span class="n">bar</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_common_config</span><span class="p">(</span><span class="n">bar</span><span class="p">)</span> + <span class="k">return</span> <span class="n">bar</span></div> + + +<div class="viewcode-block" id="ProgressBar.get_metrics"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.get_metrics">[docs]</a> + <span class="k">def</span> <span class="nf">get_metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">LightningModule</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Override to not show the version number in the logging."""</span> + <span class="n">items</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_metrics</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span> + <span class="n">items</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"v_num"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> + <span class="k">return</span> <span class="n">items</span></div> + + +<div class="viewcode-block" id="ProgressBar.on_train_epoch_start"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.on_train_epoch_start">[docs]</a> + <span class="k">def</span> <span class="nf">on_train_epoch_start</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">LightningModule</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Print the results of the previous epoch on a separate line.</span> + +<span class="sd"> This allows the user to see the losses/metrics for previous epochs</span> +<span class="sd"> while the current is training. The default behaviour in pytorch-</span> +<span class="sd"> lightning is to overwrite the progress bar from previous epochs.</span> +<span class="sd"> """</span> + <span class="k">if</span> <span class="n">trainer</span><span class="o">.</span><span class="n">current_epoch</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">train_progress_bar</span><span class="o">.</span><span class="n">set_postfix</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">get_metrics</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span> + <span class="p">)</span> + <span class="nb">print</span><span class="p">(</span><span class="s2">""</span><span class="p">)</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">on_train_epoch_start</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span> + <span class="bp">self</span><span class="o">.</span><span class="n">train_progress_bar</span><span class="o">.</span><span class="n">set_description</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="n">trainer</span><span class="o">.</span><span class="n">current_epoch</span><span class="si">:</span><span class="s2">2d</span><span class="si">}</span><span class="s2">"</span> + <span class="p">)</span></div> + + +<div class="viewcode-block" id="ProgressBar.on_train_epoch_end"> +<a class="viewcode-back" href="../../../api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.on_train_epoch_end">[docs]</a> + <span class="k">def</span> <span class="nf">on_train_epoch_end</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">LightningModule</span> + <span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Log the final progress bar for the epoch to file.</span> + +<span class="sd"> Don't duplciate to stdout.</span> +<span class="sd"> """</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">on_train_epoch_end</span><span class="p">(</span><span class="n">trainer</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span> + + <span class="k">if</span> <span class="n">rank_zero_only</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> + <span class="c1"># Construct Logger</span> + <span class="n">logger</span> <span class="o">=</span> <span class="n">Logger</span><span class="p">()</span> + + <span class="c1"># Log only to file, not stream</span> + <span class="n">h</span> <span class="o">=</span> <span class="n">logger</span><span class="o">.</span><span class="n">handlers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">logging</span><span class="o">.</span><span class="n">StreamHandler</span><span class="p">)</span> + <span class="n">level</span> <span class="o">=</span> <span class="n">h</span><span class="o">.</span><span class="n">level</span> + <span class="n">h</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">logging</span><span class="o">.</span><span class="n">ERROR</span><span class="p">)</span> + <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">train_progress_bar</span><span class="p">))</span> + <span class="n">h</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/training/labels.html b/_modules/graphnet/training/labels.html new file mode 100644 index 000000000..b2dd18a77 --- /dev/null +++ b/_modules/graphnet/training/labels.html @@ -0,0 +1,436 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.training.labels — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/training/labels" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.training.labels </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-training-labels--page-root">Source code for graphnet.training.labels</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Class(es) for constructing training labels at runtime."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Data</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span> + + +<div class="viewcode-block" id="Label"> +<a class="viewcode-back" href="../../../api/graphnet.training.labels.html#graphnet.training.labels.Label">[docs]</a> +<span class="k">class</span> <span class="nc">Label</span><span class="p">(</span><span class="n">ABC</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base `Label` class for producing labels from single `Data` instance."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `Label`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> key: The name of the field in `Data` where the label will be</span> +<span class="sd"> stored. That is, `graph[key] = label`.</span> +<span class="sd"> """</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_key</span> <span class="o">=</span> <span class="n">key</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">class_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">key</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return value of `key`."""</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_key</span> + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Label-specific implementation."""</span></div> + + + +<div class="viewcode-block" id="Direction"> +<a class="viewcode-back" href="../../../api/graphnet.training.labels.html#graphnet.training.labels.Direction">[docs]</a> +<span class="k">class</span> <span class="nc">Direction</span><span class="p">(</span><span class="n">Label</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Class for producing particle direction/pointing label."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">key</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"direction"</span><span class="p">,</span> + <span class="n">azimuth_key</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"azimuth"</span><span class="p">,</span> + <span class="n">zenith_key</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"zenith"</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct `Direction`.</span> + +<span class="sd"> Args:</span> +<span class="sd"> key: The name of the field in `Data` where the label will be</span> +<span class="sd"> stored. That is, `graph[key] = label`.</span> +<span class="sd"> azimuth_key: The name of the pre-existing key in `graph` that will</span> +<span class="sd"> be used to access the azimiuth angle, used when calculating</span> +<span class="sd"> the direction.</span> +<span class="sd"> zenith_key: The name of the pre-existing key in `graph` that will</span> +<span class="sd"> be used to access the zenith angle, used when calculating the</span> +<span class="sd"> direction.</span> +<span class="sd"> """</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_azimuth_key</span> <span class="o">=</span> <span class="n">azimuth_key</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_zenith_key</span> <span class="o">=</span> <span class="n">zenith_key</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">key</span><span class="p">)</span> + + <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">:</span> <span class="n">Data</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Compute label for `graph`."""</span> + <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_azimuth_key</span><span class="p">])</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span> + <span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_zenith_key</span><span class="p">]</span> + <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_azimuth_key</span><span class="p">])</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span> + <span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_zenith_key</span><span class="p">]</span> + <span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_zenith_key</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/training/loss_functions.html b/_modules/graphnet/training/loss_functions.html new file mode 100644 index 000000000..dd70fd272 --- /dev/null +++ b/_modules/graphnet/training/loss_functions.html @@ -0,0 +1,859 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.training.loss_functions — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/training/loss_functions" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.training.loss_functions </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-training-loss-functions--page-root">Source code for graphnet.training.loss_functions</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Collection of loss functions.</span> + +<span class="sd">All loss functions inherit from `LossFunction` which ensures a common syntax,</span> +<span class="sd">handles per-event weights, etc.</span> +<span class="sd">"""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">import</span> <span class="nn">scipy.special</span> +<span class="kn">import</span> <span class="nn">torch</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span> +<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span> +<span class="kn">from</span> <span class="nn">torch.nn.functional</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">one_hot</span><span class="p">,</span> + <span class="n">cross_entropy</span><span class="p">,</span> + <span class="n">binary_cross_entropy</span><span class="p">,</span> + <span class="n">softplus</span><span class="p">,</span> +<span class="p">)</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config</span> <span class="kn">import</span> <span class="n">save_model_config</span> +<span class="kn">from</span> <span class="nn">graphnet.models.model</span> <span class="kn">import</span> <span class="n">Model</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span> + + +<div class="viewcode-block" id="LossFunction"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction">[docs]</a> +<span class="k">class</span> <span class="nc">LossFunction</span><span class="p">(</span><span class="n">Model</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base class for loss functions in `graphnet`."""</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `LossFunction`, saving model config."""</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + +<div class="viewcode-block" id="LossFunction.forward"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction.forward">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span> <span class="c1"># type: ignore[override]</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> + <span class="n">weights</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">return_elements</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Forward pass for all loss functions.</span> + +<span class="sd"> Args:</span> +<span class="sd"> prediction: Tensor containing predictions. Shape [N,P]</span> +<span class="sd"> target: Tensor containing targets. Shape [N,T]</span> +<span class="sd"> return_elements: Whether elementwise loss terms should be returned.</span> +<span class="sd"> The alternative is to return the averaged loss across examples.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Loss, either averaged to a scalar (if `return_elements = False`) or</span> +<span class="sd"> elementwise terms with shape [N,] (if `return_elements = True`).</span> +<span class="sd"> """</span> + <span class="n">elements</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> + <span class="k">if</span> <span class="n">weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">elements</span> <span class="o">=</span> <span class="n">elements</span> <span class="o">*</span> <span class="n">weights</span> + <span class="k">assert</span> <span class="n">elements</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">(</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">0</span> + <span class="p">),</span> <span class="s2">"`_forward` should return elementwise loss terms."</span> + + <span class="k">return</span> <span class="n">elements</span> <span class="k">if</span> <span class="n">return_elements</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">elements</span><span class="p">)</span></div> + + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Syntax like `.forward`, for implentation in inheriting classes."""</span></div> + + + +<div class="viewcode-block" id="MSELoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.MSELoss">[docs]</a> +<span class="k">class</span> <span class="nc">MSELoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Mean squared error loss."""</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Implement loss calculation."""</span> + <span class="c1"># Check(s)</span> + <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> + <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> + + <span class="n">elements</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">((</span><span class="n">prediction</span> <span class="o">-</span> <span class="n">target</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> + <span class="k">return</span> <span class="n">elements</span></div> + + + +<div class="viewcode-block" id="RMSELoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.RMSELoss">[docs]</a> +<span class="k">class</span> <span class="nc">RMSELoss</span><span class="p">(</span><span class="n">MSELoss</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Root mean squared error loss."""</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Implement loss calculation."""</span> + <span class="c1"># Check(s)</span> + <span class="n">elements</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_forward</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> + <span class="n">elements</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">elements</span><span class="p">)</span> + <span class="k">return</span> <span class="n">elements</span></div> + + + +<div class="viewcode-block" id="LogCoshLoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCoshLoss">[docs]</a> +<span class="k">class</span> <span class="nc">LogCoshLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Log-cosh loss function.</span> + +<span class="sd"> Acts like x^2 for small x; and like |x| for large x.</span> +<span class="sd"> """</span> + + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">_log_cosh</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> <span class="c1"># pylint: disable=invalid-name</span> +<span class="w"> </span><span class="sd">"""Numerically stable version on log(cosh(x)).</span> + +<span class="sd"> Used to avoid `inf` for even moderately large differences.</span> +<span class="sd"> See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617]</span> +<span class="sd"> """</span> + <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">softplus</span><span class="p">(</span><span class="o">-</span><span class="mf">2.0</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mf">2.0</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Implement loss calculation."""</span> + <span class="n">diff</span> <span class="o">=</span> <span class="n">prediction</span> <span class="o">-</span> <span class="n">target</span> + <span class="n">elements</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_log_cosh</span><span class="p">(</span><span class="n">diff</span><span class="p">)</span> + <span class="k">return</span> <span class="n">elements</span></div> + + + +<div class="viewcode-block" id="CrossEntropyLoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.CrossEntropyLoss">[docs]</a> +<span class="k">class</span> <span class="nc">CrossEntropyLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Compute cross-entropy loss for classification tasks.</span> + +<span class="sd"> Predictions are an [N, num_class]-matrix of logits (i.e., non-softmax'ed</span> +<span class="sd"> probabilities), and targets are an [N,1]-matrix with integer values in</span> +<span class="sd"> (0, num_classes - 1).</span> +<span class="sd"> """</span> + + <span class="nd">@save_model_config</span> + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">options</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">Any</span><span class="p">],</span> <span class="n">Dict</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span> + <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> + <span class="p">):</span> +<span class="w"> </span><span class="sd">"""Construct CrossEntropyLoss."""</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + + <span class="c1"># Member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_options</span> <span class="o">=</span> <span class="n">options</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span><span class="p">:</span> <span class="nb">int</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> + <span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_options</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">]</span> + <span class="k">assert</span> <span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_options</span> <span class="o">>=</span> <span class="mi">2</span> + <span class="p">),</span> <span class="sa">f</span><span class="s2">"Minimum of two classes required. Got </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="si">}</span><span class="s2">."</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span> <span class="o">=</span> <span class="n">options</span> <span class="c1"># type: ignore</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">)</span> <span class="c1"># type: ignore</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span> + <span class="n">np</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="o">.</span><span class="n">values</span><span class="p">()))</span> + <span class="p">)</span> <span class="c1"># type: ignore</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Class options of type </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">)</span><span class="si">}</span><span class="s2"> not supported"</span> + <span class="p">)</span> + + <span class="bp">self</span><span class="o">.</span><span class="n">_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">"none"</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Transform outputs to angle and prepare prediction."""</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> + <span class="c1"># Integer number of classes: Targets are expected to be in</span> + <span class="c1"># (0, nb_classes - 1).</span> + + <span class="c1"># Target integers are positive</span> + <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">target</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">)</span> + + <span class="c1"># Target integers are consistent with the expected number of class.</span> + <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">target</span> <span class="o"><</span> <span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">)</span> + + <span class="k">assert</span> <span class="n">target</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">]</span> + <span class="n">target_integer</span> <span class="o">=</span> <span class="n">target</span> + + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="c1"># List of classes: Mapping target classes in list onto</span> + <span class="c1"># (0, nb_classes - 1). Example:</span> + <span class="c1"># Given options: [1, 12, 13, ...]</span> + <span class="c1"># Yields: [1, 13, 12] -> [0, 2, 1, ...]</span> + <span class="n">target_integer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span> + <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">target</span><span class="p">]</span> + <span class="p">)</span> + + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span> + <span class="c1"># Dictionary of classes: Mapping target classes in dict onto</span> + <span class="c1"># (0, nb_classes - 1). Example:</span> + <span class="c1"># Given options: {1: 0, -1: 0, 12: 1, -12: 1, ...}</span> + <span class="c1"># Yields: [1, -1, -12, ...] -> [0, 0, 1, ...]</span> + <span class="n">target_integer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span> + <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_options</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">value</span><span class="p">)]</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">target</span><span class="p">]</span> + <span class="p">)</span> + + <span class="k">else</span><span class="p">:</span> + <span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">"Shouldn't reach here."</span> + + <span class="n">target_one_hot</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="n">one_hot</span><span class="p">(</span><span class="n">target_integer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_nb_classes</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span> + <span class="n">prediction</span><span class="o">.</span><span class="n">device</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_loss</span><span class="p">(</span><span class="n">prediction</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">target_one_hot</span><span class="o">.</span><span class="n">float</span><span class="p">())</span></div> + + + +<div class="viewcode-block" id="BinaryCrossEntropyLoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.BinaryCrossEntropyLoss">[docs]</a> +<span class="k">class</span> <span class="nc">BinaryCrossEntropyLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Compute binary cross entropy loss.</span> + +<span class="sd"> Predictions are vector probabilities (i.e., values between 0 and 1), and</span> +<span class="sd"> targets should be 0 and 1.</span> +<span class="sd"> """</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="k">return</span> <span class="n">binary_cross_entropy</span><span class="p">(</span> + <span class="n">prediction</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">target</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">"none"</span> + <span class="p">)</span></div> + + + +<div class="viewcode-block" id="LogCMK"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK">[docs]</a> +<span class="k">class</span> <span class="nc">LogCMK</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""MIT License.</span> + +<span class="sd"> Copyright (c) 2019 Max Ryabinin</span> + +<span class="sd"> Permission is hereby granted, free of charge, to any person obtaining a copy</span> +<span class="sd"> of this software and associated documentation files (the "Software"), to deal</span> +<span class="sd"> in the Software without restriction, including without limitation the rights</span> +<span class="sd"> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell</span> +<span class="sd"> copies of the Software, and to permit persons to whom the Software is</span> +<span class="sd"> furnished to do so, subject to the following conditions:</span> + +<span class="sd"> The above copyright notice and this permission notice shall be included in all</span> +<span class="sd"> copies or substantial portions of the Software.</span> + +<span class="sd"> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR</span> +<span class="sd"> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,</span> +<span class="sd"> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE</span> +<span class="sd"> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER</span> +<span class="sd"> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,</span> +<span class="sd"> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE</span> +<span class="sd"> SOFTWARE.</span> +<span class="sd"> _____________________</span> + +<span class="sd"> From [https://github.com/mryab/vmf_loss/blob/master/losses.py]</span> +<span class="sd"> Modified to use modified Bessel function instead of exponentially scaled ditto</span> +<span class="sd"> (i.e. `.ive` -> `.iv`) as indiciated in [1812.04616] in spite of suggestion in</span> +<span class="sd"> Sec. 8.2 of this paper. The change has been validated through comparison with</span> +<span class="sd"> exact calculations for `m=2` and `m=3` and found to yield the correct results.</span> +<span class="sd"> """</span> + +<div class="viewcode-block" id="LogCMK.forward"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK.forward">[docs]</a> + <span class="nd">@staticmethod</span> + <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span> + <span class="n">ctx</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">m</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kappa</span><span class="p">:</span> <span class="n">Tensor</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> <span class="c1"># pylint: disable=invalid-name,arguments-differ</span> +<span class="w"> </span><span class="sd">"""Forward pass."""</span> + <span class="n">dtype</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">dtype</span> + <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">kappa</span><span class="p">)</span> + <span class="n">ctx</span><span class="o">.</span><span class="n">m</span> <span class="o">=</span> <span class="n">m</span> + <span class="n">ctx</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span> + <span class="n">kappa</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">double</span><span class="p">()</span> + <span class="n">iv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span> + <span class="n">scipy</span><span class="o">.</span><span class="n">special</span><span class="o">.</span><span class="n">iv</span><span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kappa</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> + <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">kappa</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> + <span class="k">return</span> <span class="p">(</span> + <span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">kappa</span><span class="p">)</span> + <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">iv</span><span class="p">)</span> + <span class="o">-</span> <span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> + <span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="LogCMK.backward"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK.backward">[docs]</a> + <span class="nd">@staticmethod</span> + <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span> + <span class="n">ctx</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">grad_output</span><span class="p">:</span> <span class="n">Tensor</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> <span class="c1"># pylint: disable=invalid-name,arguments-differ</span> +<span class="w"> </span><span class="sd">"""Backward pass."""</span> + <span class="n">kappa</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> + <span class="n">m</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">m</span> + <span class="n">dtype</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">dtype</span> + <span class="n">kappa</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">double</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> + <span class="n">grads</span> <span class="o">=</span> <span class="o">-</span><span class="p">(</span> + <span class="p">(</span><span class="n">scipy</span><span class="o">.</span><span class="n">special</span><span class="o">.</span><span class="n">iv</span><span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">,</span> <span class="n">kappa</span><span class="p">))</span> + <span class="o">/</span> <span class="p">(</span><span class="n">scipy</span><span class="o">.</span><span class="n">special</span><span class="o">.</span><span class="n">iv</span><span class="p">(</span><span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kappa</span><span class="p">))</span> + <span class="p">)</span> + <span class="k">return</span> <span class="p">(</span> + <span class="kc">None</span><span class="p">,</span> + <span class="n">grad_output</span> + <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">grads</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">grad_output</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span> + <span class="p">)</span></div> +</div> + + + +<div class="viewcode-block" id="VonMisesFisherLoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss">[docs]</a> +<span class="k">class</span> <span class="nc">VonMisesFisherLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""General class for calculating von Mises-Fisher loss.</span> + +<span class="sd"> Requires implementation for specific dimension `m` in which the target and</span> +<span class="sd"> prediction vectors need to be prepared.</span> +<span class="sd"> """</span> + +<div class="viewcode-block" id="VonMisesFisherLoss.log_cmk_exact"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">log_cmk_exact</span><span class="p">(</span> + <span class="bp">cls</span><span class="p">,</span> <span class="n">m</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kappa</span><span class="p">:</span> <span class="n">Tensor</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> <span class="c1"># pylint: disable=invalid-name</span> +<span class="w"> </span><span class="sd">"""Calculate $log C_{m}(k)$ term in von Mises-Fisher loss exactly."""</span> + <span class="k">return</span> <span class="n">LogCMK</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">kappa</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="VonMisesFisherLoss.log_cmk_approx"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">log_cmk_approx</span><span class="p">(</span> + <span class="bp">cls</span><span class="p">,</span> <span class="n">m</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kappa</span><span class="p">:</span> <span class="n">Tensor</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> <span class="c1"># pylint: disable=invalid-name</span> +<span class="w"> </span><span class="sd">"""Calculate $log C_{m}(k)$ term in von Mises-Fisher loss approx.</span> + +<span class="sd"> [https://arxiv.org/abs/1812.04616] Sec. 8.2 with additional minus sign.</span> +<span class="sd"> """</span> + <span class="n">v</span> <span class="o">=</span> <span class="n">m</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="o">-</span> <span class="mf">0.5</span> + <span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">((</span><span class="n">v</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">kappa</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> + <span class="n">b</span> <span class="o">=</span> <span class="n">v</span> <span class="o">-</span> <span class="mi">1</span> + <span class="k">return</span> <span class="o">-</span><span class="n">a</span> <span class="o">+</span> <span class="n">b</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">b</span> <span class="o">+</span> <span class="n">a</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="VonMisesFisherLoss.log_cmk"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">log_cmk</span><span class="p">(</span> + <span class="bp">cls</span><span class="p">,</span> <span class="n">m</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">kappa</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">kappa_switch</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">100.0</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> <span class="c1"># pylint: disable=invalid-name</span> +<span class="w"> </span><span class="sd">"""Calculate $log C_{m}(k)$ term in von Mises-Fisher loss.</span> + +<span class="sd"> Since `log_cmk_exact` is diverges for `kappa` >~ 700 (using float64</span> +<span class="sd"> precision), and since `log_cmk_approx` is unaccurate for small `kappa`,</span> +<span class="sd"> this method automatically switches between the two at `kappa_switch`,</span> +<span class="sd"> ensuring continuity at this point.</span> +<span class="sd"> """</span> + <span class="n">kappa_switch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">kappa_switch</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">kappa</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> + <span class="n">mask_exact</span> <span class="o">=</span> <span class="n">kappa</span> <span class="o"><</span> <span class="n">kappa_switch</span> + + <span class="c1"># Ensure continuity at `kappa_switch`</span> + <span class="n">offset</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">log_cmk_approx</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">kappa_switch</span><span class="p">)</span> <span class="o">-</span> <span class="bp">cls</span><span class="o">.</span><span class="n">log_cmk_exact</span><span class="p">(</span> + <span class="n">m</span><span class="p">,</span> <span class="n">kappa_switch</span> + <span class="p">)</span> + <span class="n">ret</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">log_cmk_approx</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">kappa</span><span class="p">)</span> <span class="o">-</span> <span class="n">offset</span> + <span class="n">ret</span><span class="p">[</span><span class="n">mask_exact</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">log_cmk_exact</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">kappa</span><span class="p">[</span><span class="n">mask_exact</span><span class="p">])</span> + <span class="k">return</span> <span class="n">ret</span></div> + + + <span class="k">def</span> <span class="nf">_evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Calculate von Mises-Fisher loss for a vector in D dimensons.</span> + +<span class="sd"> This loss utilises the von Mises-Fisher distribution, which is a</span> +<span class="sd"> probability distribution on the (D - 1) sphere in D-dimensional space.</span> + +<span class="sd"> Args:</span> +<span class="sd"> prediction: Predicted vector, of shape [batch_size, D].</span> +<span class="sd"> target: Target unit vector, of shape [batch_size, D].</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Elementwise von Mises-Fisher loss terms.</span> +<span class="sd"> """</span> + <span class="c1"># Check(s)</span> + <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> + <span class="k">assert</span> <span class="n">target</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> + <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()</span> + + <span class="c1"># Computing loss</span> + <span class="n">m</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span> + <span class="n">k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + <span class="n">dotprod</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">prediction</span> <span class="o">*</span> <span class="n">target</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> + <span class="n">elements</span> <span class="o">=</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">log_cmk</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="o">-</span> <span class="n">dotprod</span> + <span class="k">return</span> <span class="n">elements</span> + + <span class="nd">@abstractmethod</span> + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">NotImplementedError</span></div> + + + +<div class="viewcode-block" id="VonMisesFisher2DLoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher2DLoss">[docs]</a> +<span class="k">class</span> <span class="nc">VonMisesFisher2DLoss</span><span class="p">(</span><span class="n">VonMisesFisherLoss</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""von Mises-Fisher loss function vectors in the 2D plane."""</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Calculate von Mises-Fisher loss for an angle in the 2D plane.</span> + +<span class="sd"> Args:</span> +<span class="sd"> prediction: Output of the model. Must have shape [N, 2] where 0th</span> +<span class="sd"> column is a prediction of `angle` and 1st column is an estimate</span> +<span class="sd"> of `kappa`.</span> +<span class="sd"> target: Target tensor, extracted from graph object.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> loss: Elementwise von Mises-Fisher loss terms. Shape [N,]</span> +<span class="sd"> """</span> + <span class="c1"># Check(s)</span> + <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> <span class="ow">and</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">2</span> + <span class="k">assert</span> <span class="n">target</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> + <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span> + + <span class="c1"># Formatting target</span> + <span class="n">angle_true</span> <span class="o">=</span> <span class="n">target</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> + <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">angle_true</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">angle_true</span><span class="p">),</span> + <span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># Formatting prediction</span> + <span class="n">angle_pred</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> + <span class="n">kappa</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> + <span class="n">p</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span> + <span class="p">[</span> + <span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">angle_pred</span><span class="p">),</span> + <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">angle_pred</span><span class="p">),</span> + <span class="p">],</span> + <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_evaluate</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="EuclideanDistanceLoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.EuclideanDistanceLoss">[docs]</a> +<span class="k">class</span> <span class="nc">EuclideanDistanceLoss</span><span class="p">(</span><span class="n">LossFunction</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Mean squared error in three dimensions."""</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Calculate 3D Euclidean distance between predicted and target.</span> + +<span class="sd"> Args:</span> +<span class="sd"> prediction: Output of the model. Must have shape [N, 3]</span> +<span class="sd"> target: Target tensor, extracted from graph object.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Elementwise von Mises-Fisher loss terms. Shape [N,]</span> +<span class="sd"> """</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span> + <span class="p">(</span><span class="n">prediction</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">target</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> <span class="o">**</span> <span class="mi">2</span> + <span class="o">+</span> <span class="p">(</span><span class="n">prediction</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">target</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span> <span class="o">**</span> <span class="mi">2</span> + <span class="o">+</span> <span class="p">(</span><span class="n">prediction</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">target</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">])</span> <span class="o">**</span> <span class="mi">2</span> + <span class="p">)</span></div> + + + +<div class="viewcode-block" id="VonMisesFisher3DLoss"> +<a class="viewcode-back" href="../../../api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher3DLoss">[docs]</a> +<span class="k">class</span> <span class="nc">VonMisesFisher3DLoss</span><span class="p">(</span><span class="n">VonMisesFisherLoss</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""von Mises-Fisher loss function vectors in the 3D plane."""</span> + + <span class="k">def</span> <span class="nf">_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Calculate von Mises-Fisher loss for a direction in the 3D.</span> + +<span class="sd"> Args:</span> +<span class="sd"> prediction: Output of the model. Must have shape [N, 4] where</span> +<span class="sd"> columns 0, 1, 2 are predictions of `direction` and last column</span> +<span class="sd"> is an estimate of `kappa`.</span> +<span class="sd"> target: Target tensor, extracted from graph object.</span> + +<span class="sd"> Returns:</span> +<span class="sd"> Elementwise von Mises-Fisher loss terms. Shape [N,]</span> +<span class="sd"> """</span> + <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> + <span class="c1"># Check(s)</span> + <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> <span class="ow">and</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">4</span> + <span class="k">assert</span> <span class="n">target</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span> + <span class="k">assert</span> <span class="n">prediction</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span> + + <span class="n">kappa</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> + <span class="n">p</span> <span class="o">=</span> <span class="n">kappa</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">prediction</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_evaluate</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/training/utils.html b/_modules/graphnet/training/utils.html new file mode 100644 index 000000000..eff87f698 --- /dev/null +++ b/_modules/graphnet/training/utils.html @@ -0,0 +1,656 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.training.utils — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/training/utils" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.training.utils </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-training-utils--page-root">Source code for graphnet.training.utils</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Utility functions for `graphnet.training`."""</span> + +<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span> +<span class="kn">import</span> <span class="nn">os</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Callable</span> + +<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> +<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> +<span class="kn">from</span> <span class="nn">pytorch_lightning</span> <span class="kn">import</span> <span class="n">Trainer</span> +<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span> +<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span> +<span class="kn">from</span> <span class="nn">torch_geometric.data</span> <span class="kn">import</span> <span class="n">Batch</span><span class="p">,</span> <span class="n">Data</span> + +<span class="kn">from</span> <span class="nn">graphnet.data.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span> +<span class="kn">from</span> <span class="nn">graphnet.data.dataset</span> <span class="kn">import</span> <span class="n">SQLiteDataset</span> +<span class="kn">from</span> <span class="nn">graphnet.data.dataset</span> <span class="kn">import</span> <span class="n">ParquetDataset</span> +<span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span> +<span class="kn">from</span> <span class="nn">graphnet.models.graphs</span> <span class="kn">import</span> <span class="n">GraphDefinition</span> + + +<div class="viewcode-block" id="collate_fn"> +<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.collate_fn">[docs]</a> +<span class="k">def</span> <span class="nf">collate_fn</span><span class="p">(</span><span class="n">graphs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Data</span><span class="p">])</span> <span class="o">-></span> <span class="n">Batch</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Remove graphs with less than two DOM hits.</span> + +<span class="sd"> Should not occur in "production.</span> +<span class="sd"> """</span> + <span class="n">graphs</span> <span class="o">=</span> <span class="p">[</span><span class="n">g</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">graphs</span> <span class="k">if</span> <span class="n">g</span><span class="o">.</span><span class="n">n_pulses</span> <span class="o">></span> <span class="mi">1</span><span class="p">]</span> + <span class="k">return</span> <span class="n">Batch</span><span class="o">.</span><span class="n">from_data_list</span><span class="p">(</span><span class="n">graphs</span><span class="p">)</span></div> + + + +<span class="c1"># @TODO: Remove in favour of DataLoader{,.from_dataset_config}</span> +<div class="viewcode-block" id="make_dataloader"> +<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.make_dataloader">[docs]</a> +<span class="k">def</span> <span class="nf">make_dataloader</span><span class="p">(</span> + <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span> + <span class="n">graph_definition</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">GraphDefinition</span><span class="p">],</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">shuffle</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> + <span class="n">node_truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span><span class="p">,</span> + <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">string_selection</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span><span class="p">,</span> + <span class="n">labels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">DataLoader</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `DataLoader` instance."""</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pulsemaps</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">pulsemaps</span> <span class="o">=</span> <span class="p">[</span><span class="n">pulsemaps</span><span class="p">]</span> + + <span class="n">dataset</span> <span class="o">=</span> <span class="n">SQLiteDataset</span><span class="p">(</span> + <span class="n">path</span><span class="o">=</span><span class="n">db</span><span class="p">,</span> + <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span> + <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span> + <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span> + <span class="n">selection</span><span class="o">=</span><span class="n">selection</span><span class="p">,</span> + <span class="n">node_truth</span><span class="o">=</span><span class="n">node_truth</span><span class="p">,</span> + <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span> + <span class="n">node_truth_table</span><span class="o">=</span><span class="n">node_truth_table</span><span class="p">,</span> + <span class="n">string_selection</span><span class="o">=</span><span class="n">string_selection</span><span class="p">,</span> + <span class="n">loss_weight_table</span><span class="o">=</span><span class="n">loss_weight_table</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="o">=</span><span class="n">loss_weight_column</span><span class="p">,</span> + <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span> + <span class="p">)</span> + + <span class="c1"># adds custom labels to dataset</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span> + <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">labels</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span> + <span class="n">dataset</span><span class="o">.</span><span class="n">add_label</span><span class="p">(</span><span class="n">key</span><span class="o">=</span><span class="n">label</span><span class="p">,</span> <span class="n">fn</span><span class="o">=</span><span class="n">labels</span><span class="p">[</span><span class="n">label</span><span class="p">])</span> + + <span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span> + <span class="n">dataset</span><span class="p">,</span> + <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> + <span class="n">shuffle</span><span class="o">=</span><span class="n">shuffle</span><span class="p">,</span> + <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span> + <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="o">=</span><span class="n">persistent_workers</span><span class="p">,</span> + <span class="n">prefetch_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">dataloader</span></div> + + + +<span class="c1"># @TODO: Remove in favour of DataLoader{,.from_dataset_config}</span> +<div class="viewcode-block" id="make_train_validation_dataloader"> +<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.make_train_validation_dataloader">[docs]</a> +<span class="k">def</span> <span class="nf">make_train_validation_dataloader</span><span class="p">(</span> + <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">GraphDefinition</span><span class="p">],</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span> + <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> + <span class="n">database_indices</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">seed</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">42</span><span class="p">,</span> + <span class="n">test_size</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.33</span><span class="p">,</span> + <span class="n">num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> + <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span><span class="p">,</span> + <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">string_selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span><span class="p">,</span> + <span class="n">labels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Callable</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">DataLoader</span><span class="p">,</span> <span class="n">DataLoader</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Construct train and test `DataLoader` instances."""</span> + <span class="c1"># Reproducibility</span> + <span class="n">rng</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">default_rng</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">)</span> + <span class="c1"># Checks(s)</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pulsemaps</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> + <span class="n">pulsemaps</span> <span class="o">=</span> <span class="p">[</span><span class="n">pulsemaps</span><span class="p">]</span> + + <span class="k">if</span> <span class="n">selection</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="c1"># If no selection is provided, use all events in dataset.</span> + <span class="n">dataset</span><span class="p">:</span> <span class="n">Dataset</span> + <span class="k">if</span> <span class="n">db</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".db"</span><span class="p">):</span> + <span class="n">dataset</span> <span class="o">=</span> <span class="n">SQLiteDataset</span><span class="p">(</span> + <span class="n">path</span><span class="o">=</span><span class="n">db</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span> + <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span> + <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span> + <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span> + <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span> + <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">elif</span> <span class="n">db</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".parquet"</span><span class="p">):</span> + <span class="n">dataset</span> <span class="o">=</span> <span class="n">ParquetDataset</span><span class="p">(</span> + <span class="n">path</span><span class="o">=</span><span class="n">db</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span> + <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span> + <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span> + <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span> + <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span> + <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span> + <span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"File </span><span class="si">{</span><span class="n">db</span><span class="si">}</span><span class="s2"> with format </span><span class="si">{</span><span class="n">db</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'.'</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="si">}</span><span class="s2"> not supported."</span> + <span class="p">)</span> + <span class="n">selection</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">_get_all_indices</span><span class="p">()</span> + + <span class="c1"># Perform train/validation split</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">db</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="n">df_for_shuffle</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span> + <span class="p">{</span><span class="s2">"event_no"</span><span class="p">:</span> <span class="n">selection</span><span class="p">,</span> <span class="s2">"db"</span><span class="p">:</span> <span class="n">database_indices</span><span class="p">}</span> + <span class="p">)</span> + <span class="n">shuffled_df</span> <span class="o">=</span> <span class="n">df_for_shuffle</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span> + <span class="n">frac</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">rng</span> + <span class="p">)</span> + <span class="n">training_df</span><span class="p">,</span> <span class="n">validation_df</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span> + <span class="n">shuffled_df</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="n">test_size</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">seed</span> + <span class="p">)</span> + <span class="n">training_selection</span> <span class="o">=</span> <span class="n">training_df</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + <span class="n">validation_selection</span> <span class="o">=</span> <span class="n">validation_df</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">training_selection</span><span class="p">,</span> <span class="n">validation_selection</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span> + <span class="n">selection</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="n">test_size</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="n">seed</span> + <span class="p">)</span> + + <span class="c1"># Create DataLoaders</span> + <span class="n">common_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span> + <span class="n">db</span><span class="o">=</span><span class="n">db</span><span class="p">,</span> + <span class="n">pulsemaps</span><span class="o">=</span><span class="n">pulsemaps</span><span class="p">,</span> + <span class="n">features</span><span class="o">=</span><span class="n">features</span><span class="p">,</span> + <span class="n">truth</span><span class="o">=</span><span class="n">truth</span><span class="p">,</span> + <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> + <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span> + <span class="n">persistent_workers</span><span class="o">=</span><span class="n">persistent_workers</span><span class="p">,</span> + <span class="n">node_truth</span><span class="o">=</span><span class="n">node_truth</span><span class="p">,</span> + <span class="n">truth_table</span><span class="o">=</span><span class="n">truth_table</span><span class="p">,</span> + <span class="n">node_truth_table</span><span class="o">=</span><span class="n">node_truth_table</span><span class="p">,</span> + <span class="n">string_selection</span><span class="o">=</span><span class="n">string_selection</span><span class="p">,</span> + <span class="n">loss_weight_column</span><span class="o">=</span><span class="n">loss_weight_column</span><span class="p">,</span> + <span class="n">loss_weight_table</span><span class="o">=</span><span class="n">loss_weight_table</span><span class="p">,</span> + <span class="n">index_column</span><span class="o">=</span><span class="n">index_column</span><span class="p">,</span> + <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">,</span> + <span class="n">graph_definition</span><span class="o">=</span><span class="n">graph_definition</span><span class="p">,</span> + <span class="p">)</span> + + <span class="n">training_dataloader</span> <span class="o">=</span> <span class="n">make_dataloader</span><span class="p">(</span> + <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> + <span class="n">selection</span><span class="o">=</span><span class="n">training_selection</span><span class="p">,</span> + <span class="o">**</span><span class="n">common_kwargs</span><span class="p">,</span> <span class="c1"># type: ignore[arg-type]</span> + <span class="p">)</span> + + <span class="n">validation_dataloader</span> <span class="o">=</span> <span class="n">make_dataloader</span><span class="p">(</span> + <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> + <span class="n">selection</span><span class="o">=</span><span class="n">validation_selection</span><span class="p">,</span> + <span class="o">**</span><span class="n">common_kwargs</span><span class="p">,</span> <span class="c1"># type: ignore[arg-type]</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="p">(</span> + <span class="n">training_dataloader</span><span class="p">,</span> + <span class="n">validation_dataloader</span><span class="p">,</span> + <span class="p">)</span></div> + + + +<span class="c1"># @TODO: Remove in favour of Model.predict{,_as_dataframe}</span> +<div class="viewcode-block" id="get_predictions"> +<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.get_predictions">[docs]</a> +<span class="k">def</span> <span class="nf">get_predictions</span><span class="p">(</span> + <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span> + <span class="n">model</span><span class="p">:</span> <span class="n">Model</span><span class="p">,</span> + <span class="n">dataloader</span><span class="p">:</span> <span class="n">DataLoader</span><span class="p">,</span> + <span class="n">prediction_columns</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="o">*</span><span class="p">,</span> + <span class="n">node_level</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="n">additional_attributes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Get `model` predictions on `dataloader`."""</span> + <span class="c1"># Gets predictions from model on the events in the dataloader.</span> + <span class="c1"># NOTE: dataloader must NOT have shuffle = True!</span> + + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="n">additional_attributes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">additional_attributes</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">additional_attributes</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + + <span class="c1"># Set model to inference mode</span> + <span class="n">model</span><span class="o">.</span><span class="n">inference</span><span class="p">()</span> + + <span class="c1"># Get predictions</span> + <span class="n">predictions_torch</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">)</span> + <span class="n">predictions_list</span> <span class="o">=</span> <span class="p">[</span> + <span class="n">p</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">predictions_torch</span> + <span class="p">]</span> <span class="c1"># Assuming single task</span> + <span class="n">predictions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> + <span class="k">try</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">)</span> <span class="o">==</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> + <span class="k">except</span> <span class="ne">IndexError</span><span class="p">:</span> + <span class="n">predictions</span> <span class="o">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">prediction_columns</span><span class="p">)</span> <span class="o">==</span> <span class="n">predictions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> + + <span class="c1"># Get additional attributes</span> + <span class="n">attributes</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]]</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">(</span> + <span class="p">[(</span><span class="n">attr</span><span class="p">,</span> <span class="p">[])</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">additional_attributes</span><span class="p">]</span> + <span class="p">)</span> + <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span> + <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="n">attributes</span><span class="p">:</span> + <span class="n">attribute</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">attr</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> + <span class="k">if</span> <span class="n">node_level</span><span class="p">:</span> + <span class="k">if</span> <span class="n">attr</span> <span class="o">==</span> <span class="s2">"event_no"</span><span class="p">:</span> + <span class="n">attribute</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span> + <span class="n">attribute</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="s2">"n_pulses"</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> + <span class="p">)</span> + <span class="n">attributes</span><span class="p">[</span><span class="n">attr</span><span class="p">]</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">attribute</span><span class="p">)</span> + + <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> + <span class="p">[</span><span class="n">predictions</span><span class="p">]</span> + <span class="o">+</span> <span class="p">[</span> + <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">values</span><span class="p">)[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="k">for</span> <span class="n">values</span> <span class="ow">in</span> <span class="n">attributes</span><span class="o">.</span><span class="n">values</span><span class="p">()</span> + <span class="p">],</span> + <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> + <span class="p">)</span> + + <span class="n">results</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span> + <span class="n">data</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">prediction_columns</span> <span class="o">+</span> <span class="n">additional_attributes</span> + <span class="p">)</span> + <span class="k">return</span> <span class="n">results</span></div> + + + +<span class="c1"># @TODO: Remove</span> +<div class="viewcode-block" id="save_results"> +<a class="viewcode-back" href="../../../api/graphnet.training.utils.html#graphnet.training.utils.save_results">[docs]</a> +<span class="k">def</span> <span class="nf">save_results</span><span class="p">(</span> + <span class="n">db</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">tag</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">results</span><span class="p">:</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">,</span> <span class="n">archive</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">Model</span> +<span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Save trained model and prediction `results` in `db`."""</span> + <span class="n">db_name</span> <span class="o">=</span> <span class="n">db</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"."</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> + <span class="n">path</span> <span class="o">=</span> <span class="n">archive</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">db_name</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">tag</span> + <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="n">results</span><span class="o">.</span><span class="n">to_csv</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">"/results.csv"</span><span class="p">)</span> + <span class="n">model</span><span class="o">.</span><span class="n">save_state_dict</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">tag</span> <span class="o">+</span> <span class="s2">"_state_dict.pth"</span><span class="p">)</span> + <span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">tag</span> <span class="o">+</span> <span class="s2">"_model.pth"</span><span class="p">)</span> + <span class="n">Logger</span><span class="p">()</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">"Results saved at: </span><span class="se">\n</span><span class="s2"> </span><span class="si">%s</span><span class="s2">"</span> <span class="o">%</span> <span class="n">path</span><span class="p">)</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/training/weight_fitting.html b/_modules/graphnet/training/weight_fitting.html index 9811e69cf..909eb3dbe 100644 --- a/_modules/graphnet/training/weight_fitting.html +++ b/_modules/graphnet/training/weight_fitting.html @@ -556,7 +556,7 @@ <h1 id="modules-graphnet-training-weight-fitting--page-root">Source code for gra </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/utilities/argparse.html b/_modules/graphnet/utilities/argparse.html index 43877f0e5..a68f077b2 100644 --- a/_modules/graphnet/utilities/argparse.html +++ b/_modules/graphnet/utilities/argparse.html @@ -515,7 +515,7 @@ <h1 id="modules-graphnet-utilities-argparse--page-root">Source code for graphnet </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/utilities/config/base_config.html b/_modules/graphnet/utilities/config/base_config.html new file mode 100644 index 000000000..f57c1f185 --- /dev/null +++ b/_modules/graphnet/utilities/config/base_config.html @@ -0,0 +1,449 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.utilities.config.base_config — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/utilities/config/base_config" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.utilities.config.base_config </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-utilities-config-base-config--page-root">Source code for graphnet.utilities.config.base_config</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Base config class(es)."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">abstractmethod</span> +<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span> +<span class="kn">import</span> <span class="nn">inspect</span> +<span class="kn">import</span> <span class="nn">sys</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Optional</span> + +<span class="kn">from</span> <span class="nn">pydantic</span> <span class="kn">import</span> <span class="n">BaseModel</span> +<span class="kn">import</span> <span class="nn">ruamel.yaml</span> <span class="k">as</span> <span class="nn">yaml</span> + + +<span class="n">CONFIG_FILES_SUFFIXES</span> <span class="o">=</span> <span class="p">(</span><span class="s2">".yml"</span><span class="p">,</span> <span class="s2">".yaml"</span><span class="p">)</span> + + +<div class="viewcode-block" id="BaseConfig"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig">[docs]</a> +<span class="k">class</span> <span class="nc">BaseConfig</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base class for Configs."""</span> + +<div class="viewcode-block" id="BaseConfig.load"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.load">[docs]</a> + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"BaseConfig"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Load BaseConfig from `path`."""</span> + <span class="k">assert</span> <span class="n">path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span> + <span class="n">CONFIG_FILES_SUFFIXES</span> + <span class="p">),</span> <span class="s2">"Please specify YAML config file."</span> + <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> + <span class="n">yaml_</span> <span class="o">=</span> <span class="n">yaml</span><span class="o">.</span><span class="n">YAML</span><span class="p">(</span><span class="n">typ</span><span class="o">=</span><span class="s2">"safe"</span><span class="p">,</span> <span class="n">pure</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="n">config_dict</span> <span class="o">=</span> <span class="n">yaml_</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> + + <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">config_dict</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="BaseConfig.dump"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.dump">[docs]</a> + <span class="k">def</span> <span class="nf">dump</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Save BaseConfig to `path` as YAML file, or return as string."""</span> + <span class="n">config_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">as_dict</span><span class="p">()[</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">]</span> + <span class="n">yaml_</span> <span class="o">=</span> <span class="n">yaml</span><span class="o">.</span><span class="n">YAML</span><span class="p">(</span><span class="n">typ</span><span class="o">=</span><span class="s2">"safe"</span><span class="p">,</span> <span class="n">pure</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> + <span class="k">if</span> <span class="n">path</span><span class="p">:</span> + <span class="k">if</span> <span class="ow">not</span> <span class="n">path</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="n">CONFIG_FILES_SUFFIXES</span><span class="p">):</span> + <span class="n">path</span> <span class="o">+=</span> <span class="n">CONFIG_FILES_SUFFIXES</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> + <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> + <span class="n">yaml_</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">config_dict</span><span class="p">,</span> <span class="n">f</span><span class="p">)</span> + <span class="k">return</span> <span class="kc">None</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="n">yaml_</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">config_dict</span><span class="p">,</span> <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="BaseConfig.as_dict"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.as_dict">[docs]</a> + <span class="k">def</span> <span class="nf">as_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Represent BaseConfig as a dict.</span> + +<span class="sd"> This builds on `BaseModel.dict()` but can be overwritten.</span> +<span class="sd"> """</span> + <span class="k">return</span> <span class="p">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">dict</span><span class="p">()}</span></div> +</div> + + + +<div class="viewcode-block" id="get_all_argument_values"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.get_all_argument_values">[docs]</a> +<span class="k">def</span> <span class="nf">get_all_argument_values</span><span class="p">(</span> + <span class="n">fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return dict of all argument values to `fn`, including defaults."""</span> + <span class="c1"># Get all default argument values</span> + <span class="n">cfg</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">()</span> + <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">inspect</span><span class="o">.</span><span class="n">signature</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span><span class="o">.</span><span class="n">parameters</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> + <span class="c1"># Don't save `self`, `*args`, or `**kwargs`</span> + <span class="k">if</span> <span class="n">key</span> <span class="o">==</span> <span class="s2">"self"</span> <span class="ow">or</span> <span class="n">param</span><span class="o">.</span><span class="n">kind</span> <span class="ow">in</span> <span class="p">[</span> + <span class="n">param</span><span class="o">.</span><span class="n">VAR_POSITIONAL</span><span class="p">,</span> + <span class="n">param</span><span class="o">.</span><span class="n">VAR_KEYWORD</span><span class="p">,</span> + <span class="p">]:</span> + <span class="k">continue</span> + <span class="n">cfg</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">default</span> + + <span class="c1"># Add positional arguments</span> + <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">cfg</span><span class="o">.</span><span class="n">keys</span><span class="p">(),</span> <span class="n">args</span><span class="p">):</span> + <span class="n">cfg</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span> + + <span class="c1"># Add keyword arguments</span> + <span class="n">cfg</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">kwargs</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">cfg</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/configurable.html b/_modules/graphnet/utilities/config/configurable.html new file mode 100644 index 000000000..cfe023b65 --- /dev/null +++ b/_modules/graphnet/utilities/config/configurable.html @@ -0,0 +1,408 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.utilities.config.configurable — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/utilities/config/configurable" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.utilities.config.configurable </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-utilities-config-configurable--page-root">Source code for graphnet.utilities.config.configurable</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Bases for all configurable classes in `graphnet`."""</span> + +<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractclassmethod</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Union</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config.base_config</span> <span class="kn">import</span> <span class="n">BaseConfig</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.decorators</span> <span class="kn">import</span> <span class="n">final</span> + + +<div class="viewcode-block" id="Configurable"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable">[docs]</a> +<span class="k">class</span> <span class="nc">Configurable</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Base class for all configurable classes in graphnet."""</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `Configurable`."""</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_config</span><span class="p">:</span> <span class="n">BaseConfig</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> + + <span class="nd">@final</span> + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">BaseConfig</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return configuration to re-create the instance."""</span> + <span class="k">try</span><span class="p">:</span> + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_config</span> + <span class="k">except</span> <span class="ne">AttributeError</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span> + <span class="s2">"Config was not set. "</span> + <span class="s2">"Did you wrap the class constructor with `save_config`?"</span> + <span class="p">)</span> + +<div class="viewcode-block" id="Configurable.save_config"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.save_config">[docs]</a> + <span class="nd">@final</span> + <span class="k">def</span> <span class="nf">save_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Save Config to `path` as YAML file."""</span> + <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">path</span><span class="p">)</span></div> + + +<div class="viewcode-block" id="Configurable.from_config"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.from_config">[docs]</a> + <span class="nd">@abstractclassmethod</span> + <span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">source</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">BaseConfig</span><span class="p">,</span> <span class="nb">str</span><span class="p">])</span> <span class="o">-></span> <span class="n">Any</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct instance from `source` configuration."""</span></div> +</div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/dataset_config.html b/_modules/graphnet/utilities/config/dataset_config.html new file mode 100644 index 000000000..91b338109 --- /dev/null +++ b/_modules/graphnet/utilities/config/dataset_config.html @@ -0,0 +1,585 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.utilities.config.dataset_config — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/utilities/config/dataset_config" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.utilities.config.dataset_config </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-utilities-config-dataset-config--page-root">Source code for graphnet.utilities.config.dataset_config</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Config classes for the `graphnet.data.dataset` module."""</span> + +<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">TYPE_CHECKING</span><span class="p">,</span> + <span class="n">Any</span><span class="p">,</span> + <span class="n">Callable</span><span class="p">,</span> + <span class="n">Dict</span><span class="p">,</span> + <span class="n">List</span><span class="p">,</span> + <span class="n">Optional</span><span class="p">,</span> + <span class="n">Union</span><span class="p">,</span> +<span class="p">)</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config.base_config</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">BaseConfig</span><span class="p">,</span> + <span class="n">get_all_argument_values</span><span class="p">,</span> +<span class="p">)</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config.parsing</span> <span class="kn">import</span> <span class="n">traverse_and_apply</span> +<span class="kn">from</span> <span class="nn">.model_config</span> <span class="kn">import</span> <span class="n">ModelConfig</span> + +<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span> + <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> + + +<span class="n">BACKEND_LOOKUP</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"db"</span><span class="p">:</span> <span class="s2">"sqlite"</span><span class="p">,</span> + <span class="s2">"parquet"</span><span class="p">:</span> <span class="s2">"parquet"</span><span class="p">,</span> +<span class="p">}</span> + + +<div class="viewcode-block" id="DatasetConfig"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig">[docs]</a> +<span class="k">class</span> <span class="nc">DatasetConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Configuration for all `Dataset`s."""</span> + + <span class="c1"># Fields</span> + <span class="n">path</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> + <span class="n">pulsemaps</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> + <span class="n">features</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> + <span class="n">truth</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> + <span class="n">node_truth</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">index_column</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"event_no"</span> + <span class="n">truth_table</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"truth"</span> + <span class="n">node_truth_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">string_selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">selection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span> + <span class="n">Union</span><span class="p">[</span> + <span class="nb">str</span><span class="p">,</span> + <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> + <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]],</span> + <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]],</span> + <span class="p">]</span> + <span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">loss_weight_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">loss_weight_column</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">loss_weight_default_value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> + <span class="n">graph_definition</span><span class="p">:</span> <span class="n">Any</span> <span class="o">=</span> <span class="kc">None</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">data</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `DataConfig`.</span> + +<span class="sd"> Can be used for dataset configuration as code, thereby making dataset</span> +<span class="sd"> construction more transparent and reproducible.</span> + +<span class="sd"> Examples:</span> +<span class="sd"> In one session, do:</span> + +<span class="sd"> >>> dataset = Dataset(...)</span> +<span class="sd"> >>> dataset.config.dump()</span> +<span class="sd"> path: (...)</span> +<span class="sd"> pulsemaps:</span> +<span class="sd"> - (...)</span> +<span class="sd"> (...)</span> +<span class="sd"> >>> dataset.config.dump("dataset.yml")</span> + +<span class="sd"> In another session, you can then do:</span> +<span class="sd"> >>> dataset = Dataset.from_config("dataset.yml")</span> + +<span class="sd"> # Uniquely for `DatasetConfig`, you can also define and load</span> +<span class="sd"> # multiple datasets</span> +<span class="sd"> >>> dataset.config.selection = {</span> +<span class="sd"> "train": "event_no % 2 == 0",</span> +<span class="sd"> "test": "event_no % 2 == 1",</span> +<span class="sd"> }</span> +<span class="sd"> >>> dataset.config.dump("dataset.yml")</span> +<span class="sd"> >>> datasets: Dict[str, Dataset] = Dataset.from_config(</span> +<span class="sd"> "dataset.yml"</span> +<span class="sd"> )</span> +<span class="sd"> >>> datasets</span> +<span class="sd"> {</span> +<span class="sd"> "train": Dataset(...),</span> +<span class="sd"> "test": Dataset(...),</span> +<span class="sd"> }</span> + +<span class="sd"> # You can also combine multiple selections into a single, named</span> +<span class="sd"> # dataset</span> +<span class="sd"> >>> dataset.config.selection = {</span> +<span class="sd"> "train": [</span> +<span class="sd"> "event_no % 2 == 0 & abs(pid) == 12",</span> +<span class="sd"> "event_no % 2 == 0 & abs(pid) == 14",</span> +<span class="sd"> "event_no % 2 == 0 & abs(pid) == 16",</span> +<span class="sd"> ],</span> +<span class="sd"> (...)</span> +<span class="sd"> }</span> +<span class="sd"> >>> dataset.config.dump("dataset.yml")</span> +<span class="sd"> >>> datasets: Dict[str, EnsembleDataset] = Dataset.from_config(</span> +<span class="sd"> "dataset.yml"</span> +<span class="sd"> )</span> +<span class="sd"> >>> datasets</span> +<span class="sd"> {</span> +<span class="sd"> "train": EnsembleDataset(...),</span> +<span class="sd"> (...)</span> +<span class="sd"> }</span> + +<span class="sd"> # Finally, you can still reference existing selection files in CSV</span> +<span class="sd"> # or JSON formats:</span> +<span class="sd"> >>> dataset.config.selection = {</span> +<span class="sd"> "train": "50000 random events ~ train_selection.csv",</span> +<span class="sd"> "test": "test_selection.csv",</span> +<span class="sd"> }</span> +<span class="sd"> """</span> + <span class="c1"># Single-key dictioaries are unpacked</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s2">"selection"</span><span class="p">],</span> <span class="nb">dict</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s2">"selection"</span><span class="p">])</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> + <span class="n">data</span><span class="p">[</span><span class="s2">"selection"</span><span class="p">]</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s2">"selection"</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">()))</span> + + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> + <span class="n">path</span><span class="p">:</span> <span class="nb">str</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> + <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> + <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">path</span> + <span class="n">suffix</span> <span class="o">=</span> <span class="n">path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"."</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> + <span class="k">try</span><span class="p">:</span> + <span class="k">return</span> <span class="n">BACKEND_LOOKUP</span><span class="p">[</span><span class="n">suffix</span><span class="p">]</span> + <span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span> + <span class="bp">self</span><span class="o">.</span><span class="n">error</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Dataset at `path` </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="si">}</span><span class="s2"> with suffix </span><span class="si">{</span><span class="n">suffix</span><span class="si">}</span><span class="s2"> not "</span> + <span class="s2">"supported."</span> + <span class="p">)</span> + <span class="k">raise</span> + + <span class="nd">@property</span> + <span class="k">def</span> <span class="nf">_dataset_class</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">type</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return the `Dataset` class implementation for this configuration."""</span> + <span class="kn">from</span> <span class="nn">graphnet.data.dataset.sqlite</span> <span class="kn">import</span> <span class="n">SQLiteDataset</span> + <span class="kn">from</span> <span class="nn">graphnet.data.dataset.parquet</span> <span class="kn">import</span> <span class="n">ParquetDataset</span> + + <span class="n">dataset_class</span> <span class="o">=</span> <span class="p">{</span> + <span class="s2">"sqlite"</span><span class="p">:</span> <span class="n">SQLiteDataset</span><span class="p">,</span> + <span class="s2">"parquet"</span><span class="p">:</span> <span class="n">ParquetDataset</span><span class="p">,</span> + <span class="p">}[</span><span class="bp">self</span><span class="o">.</span><span class="n">_backend</span><span class="p">]</span> + + <span class="k">return</span> <span class="n">dataset_class</span> + +<div class="viewcode-block" id="DatasetConfig.as_dict"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict">[docs]</a> + <span class="k">def</span> <span class="nf">as_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Represent ModelConfig as a dict.</span> + +<span class="sd"> This builds on `BaseModel.dict()` but wraps the output in a single-key</span> +<span class="sd"> dictionary to make it unambiguous to identify model arguments that are</span> +<span class="sd"> themselves models.</span> +<span class="sd"> """</span> + <span class="n">config_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dict</span><span class="p">()</span> + <span class="n">config_dict</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span> + <span class="n">obj</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="o">**</span><span class="n">config_dict</span><span class="p">),</span> <span class="n">fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_parse_torch</span> + <span class="p">)</span> + <span class="k">return</span> <span class="p">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">:</span> <span class="n">config_dict</span><span class="p">}</span></div> + + + <span class="k">def</span> <span class="nf">_parse_torch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">obj</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="n">Any</span><span class="p">:</span> + <span class="kn">import</span> <span class="nn">torch</span> + + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span> + <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="fm">__str__</span><span class="p">()</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="n">obj</span></div> + + + +<div class="viewcode-block" id="save_dataset_config"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.save_dataset_config">[docs]</a> +<span class="k">def</span> <span class="nf">save_dataset_config</span><span class="p">(</span><span class="n">init_fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">)</span> <span class="o">-></span> <span class="n">Callable</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Save the arguments to `__init__` functions as member `DatasetConfig`."""</span> + + <span class="k">def</span> <span class="nf">_replace_model_instance_with_config</span><span class="p">(</span> + <span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s2">"Model"</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Replace `Model` instances in `obj` with their `ModelConfig`."""</span> + <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> + <span class="kn">import</span> <span class="nn">torch</span> + + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">Model</span><span class="p">):</span> + <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">config</span> + + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span> + <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="fm">__str__</span><span class="p">()</span> + + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="n">obj</span> + + <span class="nd">@wraps</span><span class="p">(</span><span class="n">init_fn</span><span class="p">)</span> + <span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="n">Any</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Set `DatasetConfig` after calling `init_fn`."""</span> + <span class="c1"># Call wrapped method</span> + <span class="n">ret</span> <span class="o">=</span> <span class="n">init_fn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + + <span class="c1"># Get all argument values, including defaults</span> + <span class="n">cfg</span> <span class="o">=</span> <span class="n">get_all_argument_values</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + + <span class="c1"># Handle nested `Model`s, etc.</span> + <span class="n">cfg</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">_replace_model_instance_with_config</span><span class="p">)</span> + <span class="c1"># Add `DatasetConfig` as member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_config</span> <span class="o">=</span> <span class="n">DatasetConfig</span><span class="p">(</span><span class="o">**</span><span class="n">cfg</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">ret</span> + + <span class="k">return</span> <span class="n">wrapper</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/model_config.html b/_modules/graphnet/utilities/config/model_config.html new file mode 100644 index 000000000..e63e73186 --- /dev/null +++ b/_modules/graphnet/utilities/config/model_config.html @@ -0,0 +1,654 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.utilities.config.model_config — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/utilities/config/model_config" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.utilities.config.model_config </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-utilities-config-model-config--page-root">Source code for graphnet.utilities.config.model_config</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Config classes for the `graphnet.models` module."""</span> +<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span> +<span class="kn">import</span> <span class="nn">inspect</span> +<span class="kn">import</span> <span class="nn">re</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">TYPE_CHECKING</span><span class="p">,</span> + <span class="n">Any</span><span class="p">,</span> + <span class="n">Callable</span><span class="p">,</span> + <span class="n">Dict</span><span class="p">,</span> + <span class="n">List</span><span class="p">,</span> + <span class="n">Optional</span><span class="p">,</span> + <span class="n">Union</span><span class="p">,</span> +<span class="p">)</span> +<span class="kn">import</span> <span class="nn">torch</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config.base_config</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">BaseConfig</span><span class="p">,</span> + <span class="n">get_all_argument_values</span><span class="p">,</span> +<span class="p">)</span> +<span class="kn">from</span> <span class="nn">graphnet.utilities.config.parsing</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">traverse_and_apply</span><span class="p">,</span> + <span class="n">get_all_grapnet_classes</span><span class="p">,</span> +<span class="p">)</span> + +<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span> + <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> + + +<span class="n">FUNCTION_DEFINITION_PATTERN</span> <span class="o">=</span> <span class="p">(</span> + <span class="sa">r</span><span class="s2">"^def (?P<function_name>[a-zA-Z]</span><span class="si">{1}</span><span class="s2">[a-zA-Z0-9_]+) *\(.*\) *:"</span> +<span class="p">)</span> + + +<div class="viewcode-block" id="ModelConfig"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig">[docs]</a> +<span class="k">class</span> <span class="nc">ModelConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Configuration for all `Model`s."""</span> + + <span class="c1"># Fields</span> + <span class="n">class_name</span><span class="p">:</span> <span class="nb">str</span> + <span class="n">arguments</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> + + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">data</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `ModelConfig`.</span> + +<span class="sd"> Can be used for model configuration as code, thereby making model</span> +<span class="sd"> construction more transparent and reproducible. Note that this does</span> +<span class="sd"> *not* save any trainable weights, meaning this is only a configuration</span> +<span class="sd"> for the model's hyperparameters. Any model instantiated from a</span> +<span class="sd"> ModelConfig or file will be randomly initialised, and thus should be</span> +<span class="sd"> trained.</span> + +<span class="sd"> Examples:</span> +<span class="sd"> In one session, do:</span> + +<span class="sd"> >>> model = Model(...)</span> +<span class="sd"> >>> model.config.dump()</span> +<span class="sd"> arguments:</span> +<span class="sd"> - (...): (...)</span> +<span class="sd"> class_name: Model</span> +<span class="sd"> >>> model.config.dump("model.yml")</span> + +<span class="sd"> In another session, you can then do:</span> +<span class="sd"> >>> model = Model.from_config("model.yml")</span> +<span class="sd"> """</span> + <span class="c1"># Parse any nested `ModelConfig` arguments</span> + <span class="k">for</span> <span class="n">arg</span> <span class="ow">in</span> <span class="n">data</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">]:</span> + <span class="n">value</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">][</span><span class="n">arg</span><span class="p">]</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span> + <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">elem</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">value</span><span class="p">):</span> + <span class="n">data</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">][</span><span class="n">arg</span><span class="p">][</span> + <span class="n">ix</span> + <span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parse_if_model_config_entry</span><span class="p">(</span><span class="n">elem</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">data</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">][</span><span class="n">arg</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parse_if_model_config_entry</span><span class="p">(</span> + <span class="n">value</span> + <span class="p">)</span> + <span class="c1"># Base class constructor</span> + <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span> + + <span class="k">def</span> <span class="nf">_is_model_config_entry</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">entry</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Check whether dictionary entry is a `ModelConfig`."""</span> + <span class="k">return</span> <span class="p">(</span> + <span class="nb">isinstance</span><span class="p">(</span><span class="n">entry</span><span class="p">,</span> <span class="nb">dict</span><span class="p">)</span> + <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">entry</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> + <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="ow">in</span> <span class="n">entry</span> + <span class="p">)</span> + + <span class="k">def</span> <span class="nf">_parse_if_model_config_entry</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> <span class="n">entry</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="s2">"ModelConfig"</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Parse dictionary entry to `ModelConfig`."""</span> + <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_is_model_config_entry</span><span class="p">(</span><span class="n">entry</span><span class="p">):</span> + <span class="n">config_dict</span> <span class="o">=</span> <span class="n">entry</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">]</span> + <span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="p">(</span><span class="o">**</span><span class="n">config_dict</span><span class="p">)</span> + <span class="k">return</span> <span class="n">config</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="n">entry</span> + + <span class="k">def</span> <span class="nf">_construct_model</span><span class="p">(</span> + <span class="bp">self</span><span class="p">,</span> + <span class="n">trust</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> + <span class="n">load_modules</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> + <span class="p">)</span> <span class="o">-></span> <span class="s2">"Model"</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Construct `Model` instance from `self` configuration.</span> + +<span class="sd"> Used as the basis for `Model.from_config`.</span> +<span class="sd"> """</span> + <span class="c1"># Check(s)</span> + <span class="k">if</span> <span class="n">load_modules</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">load_modules</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"torch"</span><span class="p">]</span> + <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">load_modules</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> + + <span class="c1"># Load any additional modules into the global namespace</span> + <span class="k">for</span> <span class="n">module</span> <span class="ow">in</span> <span class="n">load_modules</span><span class="p">:</span> + <span class="k">assert</span> <span class="n">re</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="s2">"^[a-zA-Z_]+$"</span><span class="p">,</span> <span class="n">module</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> + <span class="k">if</span> <span class="n">module</span> <span class="ow">in</span> <span class="nb">globals</span><span class="p">():</span> + <span class="k">continue</span> + <span class="n">exec</span><span class="p">(</span><span class="sa">f</span><span class="s2">"import </span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="nb">globals</span><span class="p">())</span> + + <span class="c1"># Get a lookup for all classes in `graphnet`</span> + <span class="kn">import</span> <span class="nn">graphnet.data</span> + <span class="kn">import</span> <span class="nn">graphnet.models</span> + <span class="kn">import</span> <span class="nn">graphnet.training</span> + + <span class="n">namespace_classes</span> <span class="o">=</span> <span class="n">get_all_grapnet_classes</span><span class="p">(</span> + <span class="n">graphnet</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">graphnet</span><span class="o">.</span><span class="n">models</span><span class="p">,</span> <span class="n">graphnet</span><span class="o">.</span><span class="n">training</span> + <span class="p">)</span> + + <span class="c1"># Parse potential ModelConfig arguments</span> + <span class="n">arguments</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">arguments</span><span class="p">)</span> + <span class="n">arguments</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span> + <span class="n">arguments</span><span class="p">,</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_deserialise</span><span class="p">,</span> + <span class="n">fn_kwargs</span><span class="o">=</span><span class="p">{</span><span class="s2">"trust"</span><span class="p">:</span> <span class="n">trust</span><span class="p">},</span> + <span class="p">)</span> + + <span class="c1"># Construct model based on arguments</span> + <span class="k">return</span> <span class="n">namespace_classes</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">class_name</span><span class="p">](</span><span class="o">**</span><span class="n">arguments</span><span class="p">)</span> + + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">_deserialise</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">obj</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">trust</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Any</span><span class="p">:</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">):</span> + <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> + + <span class="k">return</span> <span class="n">Model</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">trust</span><span class="o">=</span><span class="n">trust</span><span class="p">)</span> + + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"!lambda"</span><span class="p">):</span> + <span class="k">if</span> <span class="n">trust</span><span class="p">:</span> + <span class="n">source</span> <span class="o">=</span> <span class="n">obj</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> + <span class="n">f</span> <span class="o">=</span> <span class="nb">eval</span><span class="p">(</span><span class="n">source</span><span class="p">)</span> + + <span class="c1"># Save a copy of the source code attached to the callable,</span> + <span class="c1"># since the `inspect` module is not able to get the source code</span> + <span class="c1"># for functions that are not defined on file.</span> + <span class="c1"># See `self._serialise`.</span> + <span class="n">f</span><span class="o">.</span><span class="n">_source</span> <span class="o">=</span> <span class="n">source</span> + <span class="k">return</span> <span class="n">f</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> + <span class="s2">"Constructing model containing a lambda function "</span> + <span class="sa">f</span><span class="s2">"(</span><span class="si">{</span><span class="n">obj</span><span class="si">}</span><span class="s2">) with `trust=False`. If you trust the lambda "</span> + <span class="s2">"functions in this ModelConfig, set `trust=True` and "</span> + <span class="s2">"reconstruct the model again."</span> + <span class="p">)</span> + + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"!function"</span><span class="p">):</span> + <span class="k">if</span> <span class="n">trust</span><span class="p">:</span> + <span class="n">source</span> <span class="o">=</span> <span class="n">obj</span><span class="p">[</span><span class="mi">10</span><span class="p">:]</span> + <span class="n">match</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="n">FUNCTION_DEFINITION_PATTERN</span><span class="p">,</span> <span class="n">source</span><span class="p">)</span> + <span class="k">assert</span> <span class="n">match</span> + <span class="n">exec</span><span class="p">(</span><span class="n">source</span><span class="p">)</span> + <span class="n">fn</span> <span class="o">=</span> <span class="nb">eval</span><span class="p">(</span><span class="n">match</span><span class="o">.</span><span class="n">group</span><span class="p">(</span><span class="s2">"function_name"</span><span class="p">))</span> + <span class="k">return</span> <span class="n">fn</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Constructing model containing a function (</span><span class="si">{</span><span class="n">obj</span><span class="si">}</span><span class="s2">) with "</span> + <span class="s2">"`trust=False`. If you trust the functions in this "</span> + <span class="s2">"ModelConfig, set `trust=True` and reconstruct the model "</span> + <span class="s2">"again."</span> + <span class="p">)</span> + + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"!class"</span><span class="p">):</span> + <span class="k">if</span> <span class="n">trust</span><span class="p">:</span> + <span class="n">module</span><span class="p">,</span> <span class="n">class_name</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">split</span><span class="p">()[</span><span class="mi">1</span><span class="p">:]</span> + <span class="n">exec</span><span class="p">(</span><span class="sa">f</span><span class="s2">"from </span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2"> import </span><span class="si">{</span><span class="n">class_name</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> + <span class="k">return</span> <span class="nb">eval</span><span class="p">(</span><span class="n">class_name</span><span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Constructing model containing a class (</span><span class="si">{</span><span class="n">obj</span><span class="si">}</span><span class="s2">) with "</span> + <span class="s2">"`trust=False`. If you trust the class definitions in "</span> + <span class="s2">"this ModelConfig, set `trust=True` and reconstruct the "</span> + <span class="s2">"model again."</span> + <span class="p">)</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"torch"</span><span class="p">):</span> + <span class="k">return</span> <span class="nb">eval</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span> + + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="n">obj</span> + + <span class="nd">@classmethod</span> + <span class="k">def</span> <span class="nf">_serialise</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">obj</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="n">Any</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Serialise `obj` to a format that can be saved to file."""</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">):</span> + <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">as_dict</span><span class="p">()</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">type</span><span class="p">):</span> + <span class="k">return</span> <span class="sa">f</span><span class="s2">"!class </span><span class="si">{</span><span class="n">obj</span><span class="o">.</span><span class="vm">__module__</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">obj</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">"</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">):</span> + <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="fm">__str__</span><span class="p">()</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">Callable</span><span class="p">):</span> <span class="c1"># type: ignore[arg-type]</span> + <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s2">"__name__"</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"<lambda>"</span><span class="p">:</span> + <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s2">"_source"</span><span class="p">):</span> + <span class="c1"># If source code is set manually during deserialisation.</span> + <span class="c1"># See `self._deserialise`.</span> + <span class="n">source</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">_source</span> + <span class="k">else</span><span class="p">:</span> + <span class="n">source</span> <span class="o">=</span> <span class="n">inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"="</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">strip</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2"> ,"</span><span class="p">)</span> + + <span class="k">return</span> <span class="s2">"!"</span> <span class="o">+</span> <span class="n">source</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">try</span><span class="p">:</span> + <span class="n">source</span> <span class="o">=</span> <span class="n">inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span> + <span class="n">match</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="n">FUNCTION_DEFINITION_PATTERN</span><span class="p">,</span> <span class="n">source</span><span class="p">)</span> + <span class="k">if</span> <span class="n">match</span> <span class="ow">and</span> <span class="n">match</span><span class="o">.</span><span class="n">group</span><span class="p">(</span><span class="s2">"function_name"</span><span class="p">):</span> + <span class="k">return</span> <span class="sa">f</span><span class="s2">"!function </span><span class="si">{</span><span class="n">source</span><span class="si">}</span><span class="s2">"</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">raise</span> <span class="ne">ValueError</span> + <span class="k">except</span> <span class="p">(</span><span class="ne">TypeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">):</span> + <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Object `</span><span class="si">{</span><span class="n">obj</span><span class="si">}</span><span class="s2">` is callable but not a lambda or "</span> + <span class="s2">"regular function. Please wrap in a, e.g., lambda "</span> + <span class="s2">"function to allow for saving this function verbatim "</span> + <span class="s2">"in a model config file."</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">obj</span> + +<div class="viewcode-block" id="ModelConfig.as_dict"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.as_dict">[docs]</a> + <span class="k">def</span> <span class="nf">as_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span> +<span class="w"> </span><span class="sd">"""Represent ModelConfig as a dict.</span> + +<span class="sd"> This builds on `BaseModel.dict()` but wraps the output in a single-key</span> +<span class="sd"> dictionary to make it unambiguous to identify model arguments that are</span> +<span class="sd"> themselves models.</span> +<span class="sd"> """</span> + <span class="n">config_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dict</span><span class="p">()</span> + <span class="n">config_dict</span><span class="p">[</span><span class="s2">"arguments"</span><span class="p">]</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span> + <span class="bp">self</span><span class="o">.</span><span class="n">arguments</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_serialise</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="p">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">:</span> <span class="n">config_dict</span><span class="p">}</span></div> +</div> + + + +<div class="viewcode-block" id="save_model_config"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.save_model_config">[docs]</a> +<span class="k">def</span> <span class="nf">save_model_config</span><span class="p">(</span><span class="n">init_fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">)</span> <span class="o">-></span> <span class="n">Callable</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Save the arguments to `__init__` functions as a member `ModelConfig`."""</span> + + <span class="k">def</span> <span class="nf">_replace_model_instance_with_config</span><span class="p">(</span> + <span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s2">"Model"</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> + <span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="n">ModelConfig</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Replace `Model` instances in `obj` with their `ModelConfig`."""</span> + <span class="kn">from</span> <span class="nn">graphnet.models</span> <span class="kn">import</span> <span class="n">Model</span> + + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">Model</span><span class="p">):</span> + <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">config</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">return</span> <span class="n">obj</span> + + <span class="nd">@wraps</span><span class="p">(</span><span class="n">init_fn</span><span class="p">)</span> + <span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="n">Any</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Set `ModelConfig` after calling `init_fn`."""</span> + <span class="c1"># Call wrapped method</span> + <span class="n">ret</span> <span class="o">=</span> <span class="n">init_fn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + + <span class="c1"># Get all argument values, including defaults</span> + <span class="n">cfg</span> <span class="o">=</span> <span class="n">get_all_argument_values</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> + + <span class="c1"># Handle nested `Model`s, etc.</span> + <span class="n">cfg</span> <span class="o">=</span> <span class="n">traverse_and_apply</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">_replace_model_instance_with_config</span><span class="p">)</span> + + <span class="c1"># Add `ModelConfig` as member variables</span> + <span class="bp">self</span><span class="o">.</span><span class="n">_config</span> <span class="o">=</span> <span class="n">ModelConfig</span><span class="p">(</span> + <span class="n">class_name</span><span class="o">=</span><span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">),</span> + <span class="n">arguments</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="o">**</span><span class="n">cfg</span><span class="p">),</span> + <span class="p">)</span> + + <span class="k">return</span> <span class="n">ret</span> + + <span class="k">return</span> <span class="n">wrapper</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/parsing.html b/_modules/graphnet/utilities/config/parsing.html new file mode 100644 index 000000000..e1bff4682 --- /dev/null +++ b/_modules/graphnet/utilities/config/parsing.html @@ -0,0 +1,475 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.utilities.config.parsing — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/utilities/config/parsing" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.utilities.config.parsing </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-utilities-config-parsing--page-root">Source code for graphnet.utilities.config.parsing</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Utility functions for parsing for using with Config-classes."""</span> + +<span class="kn">import</span> <span class="nn">itertools</span> +<span class="kn">import</span> <span class="nn">pkgutil</span> +<span class="kn">import</span> <span class="nn">types</span> +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span> + <span class="n">Any</span><span class="p">,</span> + <span class="n">Callable</span><span class="p">,</span> + <span class="n">Dict</span><span class="p">,</span> + <span class="n">List</span><span class="p">,</span> + <span class="n">Optional</span><span class="p">,</span> +<span class="p">)</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.logging</span> <span class="kn">import</span> <span class="n">Logger</span> + + +<div class="viewcode-block" id="traverse_and_apply"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.traverse_and_apply">[docs]</a> +<span class="k">def</span> <span class="nf">traverse_and_apply</span><span class="p">(</span> + <span class="n">obj</span><span class="p">:</span> <span class="n">Any</span><span class="p">,</span> <span class="n">fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">fn_kwargs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span> +<span class="p">)</span> <span class="o">-></span> <span class="n">Any</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Apply `fn` to all elements in `obj`, resulting in same structure."""</span> + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span> + <span class="k">return</span> <span class="p">[</span><span class="n">traverse_and_apply</span><span class="p">(</span><span class="n">elem</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">fn_kwargs</span><span class="p">)</span> <span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="n">obj</span><span class="p">]</span> + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span> + <span class="k">return</span> <span class="p">{</span> + <span class="n">key</span><span class="p">:</span> <span class="n">traverse_and_apply</span><span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">fn_kwargs</span><span class="p">)</span> + <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">obj</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> + <span class="p">}</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">if</span> <span class="n">fn_kwargs</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> + <span class="n">fn_kwargs</span> <span class="o">=</span> <span class="p">{}</span> + <span class="k">return</span> <span class="n">fn</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="o">**</span><span class="n">fn_kwargs</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="list_all_submodules"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.list_all_submodules">[docs]</a> +<span class="k">def</span> <span class="nf">list_all_submodules</span><span class="p">(</span><span class="o">*</span><span class="n">packages</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""List all submodules in `packages` recursively."""</span> + <span class="c1"># Resolve one or more packages</span> + <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">packages</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span> + <span class="k">return</span> <span class="nb">list</span><span class="p">(</span> + <span class="n">itertools</span><span class="o">.</span><span class="n">chain</span><span class="o">.</span><span class="n">from_iterable</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="n">list_all_submodules</span><span class="p">,</span> <span class="n">packages</span><span class="p">))</span> + <span class="p">)</span> + <span class="k">else</span><span class="p">:</span> + <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">packages</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"No packages specified"</span> + <span class="n">package</span> <span class="o">=</span> <span class="n">packages</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> + + <span class="n">submodules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> + <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">module_name</span><span class="p">,</span> <span class="n">is_pkg</span> <span class="ow">in</span> <span class="n">pkgutil</span><span class="o">.</span><span class="n">walk_packages</span><span class="p">(</span> + <span class="n">package</span><span class="o">.</span><span class="n">__path__</span><span class="p">,</span> <span class="n">package</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">+</span> <span class="s2">"."</span> + <span class="p">):</span> + <span class="n">module</span> <span class="o">=</span> <span class="nb">__import__</span><span class="p">(</span><span class="n">module_name</span><span class="p">,</span> <span class="n">fromlist</span><span class="o">=</span><span class="s2">"dummylist"</span><span class="p">)</span> + <span class="n">submodules</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">module</span><span class="p">)</span> + <span class="k">if</span> <span class="n">is_pkg</span><span class="p">:</span> + <span class="n">submodules</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">list_all_submodules</span><span class="p">(</span><span class="n">module</span><span class="p">))</span> + + <span class="k">return</span> <span class="n">submodules</span></div> + + + +<div class="viewcode-block" id="get_all_grapnet_classes"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_all_grapnet_classes">[docs]</a> +<span class="k">def</span> <span class="nf">get_all_grapnet_classes</span><span class="p">(</span><span class="o">*</span><span class="n">packages</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">type</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""List all grapnet classes in `packages`."""</span> + <span class="n">submodules</span> <span class="o">=</span> <span class="n">list_all_submodules</span><span class="p">(</span><span class="o">*</span><span class="n">packages</span><span class="p">)</span> + <span class="n">classes</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">type</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span> + <span class="k">for</span> <span class="n">submodule</span> <span class="ow">in</span> <span class="n">submodules</span><span class="p">:</span> + <span class="n">new_classes</span> <span class="o">=</span> <span class="n">get_graphnet_classes</span><span class="p">(</span><span class="n">submodule</span><span class="p">)</span> + <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">new_classes</span><span class="p">:</span> + <span class="k">if</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">classes</span> <span class="ow">and</span> <span class="n">classes</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">!=</span> <span class="n">new_classes</span><span class="p">[</span><span class="n">key</span><span class="p">]:</span> + <span class="n">Logger</span><span class="p">()</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span> + <span class="sa">f</span><span class="s2">"Class </span><span class="si">{</span><span class="n">key</span><span class="si">}</span><span class="s2"> found in both </span><span class="si">{</span><span class="n">classes</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="si">}</span><span class="s2"> and "</span> + <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">new_classes</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="si">}</span><span class="s2">. Keeping first instance. "</span> + <span class="s2">"Consider renaming."</span> + <span class="p">)</span> + <span class="n">classes</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">new_classes</span><span class="p">)</span> + + <span class="k">return</span> <span class="n">classes</span></div> + + + +<div class="viewcode-block" id="is_graphnet_module"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_module">[docs]</a> +<span class="k">def</span> <span class="nf">is_graphnet_module</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return whether `obj` is a module in graphnet."""</span> + <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="vm">__name__</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span> + <span class="s2">"graphnet."</span> + <span class="p">)</span></div> + + + +<div class="viewcode-block" id="is_graphnet_class"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_class">[docs]</a> +<span class="k">def</span> <span class="nf">is_graphnet_class</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="nb">type</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return whether `obj` is a class in graphnet."""</span> + <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">type</span><span class="p">)</span> <span class="ow">and</span> <span class="n">obj</span><span class="o">.</span><span class="vm">__module__</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"graphnet."</span><span class="p">)</span></div> + + + +<div class="viewcode-block" id="get_graphnet_classes"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_graphnet_classes">[docs]</a> +<span class="k">def</span> <span class="nf">get_graphnet_classes</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">ModuleType</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">type</span><span class="p">]:</span> +<span class="w"> </span><span class="sd">"""Return a lookup of all graphnet class names in `module`."""</span> + <span class="k">if</span> <span class="ow">not</span> <span class="n">is_graphnet_module</span><span class="p">(</span><span class="n">module</span><span class="p">):</span> + <span class="n">Logger</span><span class="p">()</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2"> is not a graphnet module"</span><span class="p">)</span> + <span class="k">return</span> <span class="p">{}</span> + <span class="n">classes</span> <span class="o">=</span> <span class="p">{</span> + <span class="n">key</span><span class="p">:</span> <span class="n">val</span> + <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> + <span class="k">if</span> <span class="n">is_graphnet_class</span><span class="p">(</span><span class="n">val</span><span class="p">)</span> + <span class="p">}</span> + <span class="k">return</span> <span class="n">classes</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/training_config.html b/_modules/graphnet/utilities/config/training_config.html new file mode 100644 index 000000000..add0468a2 --- /dev/null +++ b/_modules/graphnet/utilities/config/training_config.html @@ -0,0 +1,378 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.utilities.config.training_config — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../../_static/material.css?v=79c92029" /> + <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../../about.html" /> + <link rel="index" title="Index" href="../../../../genindex.html" /> + <link rel="search" title="Search" href="../../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/utilities/config/training_config" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.utilities.config.training_config </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../../"versions.json"", + target_loc = "../../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-utilities-config-training-config--page-root">Source code for graphnet.utilities.config.training_config</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Config classes for the `graphnet.training` module."""</span> + +<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span> + +<span class="kn">from</span> <span class="nn">graphnet.utilities.config.base_config</span> <span class="kn">import</span> <span class="n">BaseConfig</span> + + +<div class="viewcode-block" id="TrainingConfig"> +<a class="viewcode-back" href="../../../../api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig">[docs]</a> +<span class="k">class</span> <span class="nc">TrainingConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span> +<span class="w"> </span><span class="sd">"""Configuration for all trainings."""</span> + + <span class="c1"># Fields</span> + <span class="n">target</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> + <span class="n">early_stopping_patience</span><span class="p">:</span> <span class="nb">int</span> + <span class="n">fit</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> + <span class="n">dataloader</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/graphnet/utilities/filesys.html b/_modules/graphnet/utilities/filesys.html index ef2600ad3..121c9bfe5 100644 --- a/_modules/graphnet/utilities/filesys.html +++ b/_modules/graphnet/utilities/filesys.html @@ -447,7 +447,7 @@ <h1 id="modules-graphnet-utilities-filesys--page-root">Source code for graphnet. </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/utilities/imports.html b/_modules/graphnet/utilities/imports.html index 40699ad6e..80debdadf 100644 --- a/_modules/graphnet/utilities/imports.html +++ b/_modules/graphnet/utilities/imports.html @@ -420,7 +420,7 @@ <h1 id="modules-graphnet-utilities-imports--page-root">Source code for graphnet. </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/utilities/logging.html b/_modules/graphnet/utilities/logging.html index cb014897d..aa1b8a091 100644 --- a/_modules/graphnet/utilities/logging.html +++ b/_modules/graphnet/utilities/logging.html @@ -630,7 +630,7 @@ <h1 id="modules-graphnet-utilities-logging--page-root">Source code for graphnet. </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/_modules/graphnet/utilities/maths.html b/_modules/graphnet/utilities/maths.html new file mode 100644 index 000000000..607211db6 --- /dev/null +++ b/_modules/graphnet/utilities/maths.html @@ -0,0 +1,371 @@ +<!DOCTYPE html> + +<html lang="en" data-content_root="../../../"> + <head> + <meta charset="utf-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + + + + <meta name="viewport" content="width=device-width,initial-scale=1"> + <meta http-equiv="x-ua-compatible" content="ie=edge"> + <meta name="lang:clipboard.copy" content="Copy to clipboard"> + <meta name="lang:clipboard.copied" content="Copied to clipboard"> + <meta name="lang:search.language" content="en"> + <meta name="lang:search.pipeline.stopwords" content="True"> + <meta name="lang:search.pipeline.trimmer" content="True"> + <meta name="lang:search.result.none" content="No matching documents"> + <meta name="lang:search.result.one" content="1 matching document"> + <meta name="lang:search.result.other" content="# matching documents"> + <meta name="lang:search.tokenizer" content="[\s\-]+"> + + + <link href="https://fonts.gstatic.com/" rel="preconnect" crossorigin> + <link href="https://fonts.googleapis.com/css?family=Roboto+Mono:400,500,700|Roboto:300,400,400i,700&display=fallback" rel="stylesheet"> + + <style> + body, + input { + font-family: "Roboto", "Helvetica Neue", Helvetica, Arial, sans-serif + } + + code, + kbd, + pre { + font-family: "Roboto Mono", "Courier New", Courier, monospace + } + </style> + + + <link rel="stylesheet" href="../../../_static/stylesheets/application.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-palette.css"/> + <link rel="stylesheet" href="../../../_static/stylesheets/application-fixes.css"/> + + <link rel="stylesheet" href="../../../_static/fonts/material-icons.css"/> + + <meta name="theme-color" content="#3f51b5"> + <script src="../../../_static/javascripts/modernizr.js"></script> + +<script async src="https://www.googletagmanager.com/gtag/js?id=UA-XXXXX"></script> +<script> + window.dataLayer = window.dataLayer || []; + + function gtag() { + dataLayer.push(arguments); + } + + gtag('js', new Date()); + + gtag('config', 'UA-XXXXX'); +</script> + + + <title>graphnet.utilities.maths — graphnet documentation</title> + +<style> + dt:target { + margin-top: 0; + padding-top: 0; + } + + + /* + .sig-prename { + display: none; + } + */ + + .py.class .sig-name, + .py.function .sig-name, + .py.method .sig-name, + .py.exception .sig-name { + color: #37474f; + font-feature-settings: "kern"; + font-family: "Roboto Mono", "Courier New", Courier, monospace; + font-weight: 700; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.method .sig-object, + .py.exception .sig-object { + padding: 1ex; + } + + .py.class .sig-object, + .py.function .sig-object, + .py.exception .sig-object { + border-top: 1px solid gray; + } + .py.method .sig-object { + border-top: 1px solid lightgray; + } + + .py.class .sig-object, + .py.exception .sig-object { + background: rgba(0,0,0,0.06); + } + .py.function .sig-object, + .py.method .sig-object { + background: rgba(0,0,0,0.03); + } + + #eu-emblem { + margin: 0; + } + + #eu-emblem figcaption { + display: none; + } + +</style> + <link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=83e35b93" /> + <link rel="stylesheet" type="text/css" href="../../../_static/material.css?v=79c92029" /> + <script src="../../../_static/documentation_options.js?v=5929fcd5"></script> + <script src="../../../_static/doctools.js?v=888ff710"></script> + <script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script> + <link rel="icon" href="../../../_static/favicon.svg"/> + <link rel="author" title="About these documents" href="../../../about.html" /> + <link rel="index" title="Index" href="../../../genindex.html" /> + <link rel="search" title="Search" href="../../../search.html" /> + + + + </head> + <body dir=ltr + data-md-color-primary=indigo data-md-color-accent=blue> + + <svg class="md-svg"> + <defs data-children-count="0"> + + <svg xmlns="http://www.w3.org/2000/svg" width="416" height="448" viewBox="0 0 416 448" id="__github"><path fill="currentColor" d="M160 304q0 10-3.125 20.5t-10.75 19T128 352t-18.125-8.5-10.75-19T96 304t3.125-20.5 10.75-19T128 256t18.125 8.5 10.75 19T160 304zm160 0q0 10-3.125 20.5t-10.75 19T288 352t-18.125-8.5-10.75-19T256 304t3.125-20.5 10.75-19T288 256t18.125 8.5 10.75 19T320 304zm40 0q0-30-17.25-51T296 232q-10.25 0-48.75 5.25Q229.5 240 208 240t-39.25-2.75Q130.75 232 120 232q-29.5 0-46.75 21T56 304q0 22 8 38.375t20.25 25.75 30.5 15 35 7.375 37.25 1.75h42q20.5 0 37.25-1.75t35-7.375 30.5-15 20.25-25.75T360 304zm56-44q0 51.75-15.25 82.75-9.5 19.25-26.375 33.25t-35.25 21.5-42.5 11.875-42.875 5.5T212 416q-19.5 0-35.5-.75t-36.875-3.125-38.125-7.5-34.25-12.875T37 371.5t-21.5-28.75Q0 312 0 260q0-59.25 34-99-6.75-20.5-6.75-42.5 0-29 12.75-54.5 27 0 47.5 9.875t47.25 30.875Q171.5 96 212 96q37 0 70 8 26.25-20.5 46.75-30.25T376 64q12.75 25.5 12.75 54.5 0 21.75-6.75 42 34 40 34 99.5z"/></svg> + + </defs> + </svg> + + <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer"> + <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search"> + <label class="md-overlay" data-md-component="overlay" for="__drawer"></label> + <a href="#_modules/graphnet/utilities/maths" tabindex="1" class="md-skip"> Skip to content </a> + <header class="md-header" data-md-component="header"> + <nav class="md-header-nav md-grid"> + <div class="md-flex navheader"> + <div class="md-flex__cell md-flex__cell--shrink"> + <a href="../../../index.html" title="graphnet documentation" + class="md-header-nav__button md-logo"> + + + + </a> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--menu md-header-nav__button" for="__drawer"></label> + </div> + <div class="md-flex__cell md-flex__cell--stretch"> + <div class="md-flex__ellipsis md-header-nav__title" data-md-component="title"> + <span class="md-header-nav__topic">GraphNeT</span> + <span class="md-header-nav__topic"> graphnet.utilities.maths </span> + </div> + </div> + <div class="md-flex__cell md-flex__cell--shrink"> + <label class="md-icon md-icon--search md-header-nav__button" for="__search"></label> + +<div class="md-search" data-md-component="search" role="dialog"> + <label class="md-search__overlay" for="__search"></label> + <div class="md-search__inner" role="search"> + <form class="md-search__form" action="../../../search.html" method="get" name="search"> + <input type="text" class="md-search__input" name="q" placeholder=""Search"" + autocapitalize="off" autocomplete="off" spellcheck="false" + data-md-component="query" data-md-state="active"> + <label class="md-icon md-search__icon" for="__search"></label> + <button type="reset" class="md-icon md-search__icon" data-md-component="reset" tabindex="-1"> +  + </button> + </form> + <div class="md-search__output"> + <div class="md-search__scrollwrap" data-md-scrollfix> + <div class="md-search-result" data-md-component="result"> + <div class="md-search-result__meta"> + Type to start searching + </div> + <ol class="md-search-result__list"></ol> + </div> + </div> + </div> + </div> +</div> + + </div> + + <div class="md-flex__cell md-flex__cell--shrink"> + <div class="md-header-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + </div> + + + + <script src="../../../_static/javascripts/version_dropdown.js"></script> + <script> + var json_loc = "../../../"versions.json"", + target_loc = "../../../../", + text = "Versions"; + $( document ).ready( add_version_dropdown(json_loc, target_loc, text)); + </script> + + + </div> + </nav> +</header> + + + <div class="md-container"> + + + + <nav class="md-tabs" data-md-component="tabs"> + <div class="md-tabs__inner md-grid"> + <ul class="md-tabs__list"> + + <li class="md-tabs__item"><a href="../../../index.html" class="md-tabs__link">Documentation</a></li> + <li class="md-tabs__item"><a href="../../index.html" class="md-tabs__link">Module code</a></li> + </ul> + </div> + </nav> + <main class="md-main"> + <div class="md-main__inner md-grid" data-md-component="container"> + + <div class="md-sidebar md-sidebar--primary" data-md-component="navigation"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + <nav class="md-nav md-nav--primary" data-md-level="0"> + <label class="md-nav__title md-nav__title--site" for="__drawer"> + <a href="../../../index.html" title="graphnet documentation" class="md-nav__button md-logo"> + + <img src="../../../_static/" alt=" logo" width="48" height="48"> + + </a> + <a href="../../../index.html" + title="graphnet documentation">GraphNeT</a> + </label> + <div class="md-nav__source"> + <a href="https://github.com/graphnet-team/graphnet/" title="Go to repository" class="md-source" data-md-source="github"> + + <div class="md-source__icon"> + <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 0 24 24" width="28" height="28"> + <use xlink:href="#__github" width="24" height="24"></use> + </svg> + </div> + + <div class="md-source__repository"> + GraphNeT + </div> +</a> + </div> + + + + + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="../../../install.html" class="md-nav__link">Install</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../contribute.html" class="md-nav__link">Contribute</a> + + + </li> + <li class="md-nav__item"> + + + <a href="../../../api/graphnet.html" class="md-nav__link">API</a> + + + </li> + </ul> + + +</nav> + </div> + </div> + </div> + <div class="md-sidebar md-sidebar--secondary" data-md-component="toc"> + <div class="md-sidebar__scrollwrap"> + <div class="md-sidebar__inner"> + +<nav class="md-nav md-nav--secondary"> + <ul class="md-nav__list" data-md-scrollfix=""> + </ul> +</nav> + </div> + </div> + </div> + + <div class="md-content"> + <article class="md-content__inner md-typeset" role="main"> + + <h1 id="modules-graphnet-utilities-maths--page-root">Source code for graphnet.utilities.maths</h1><div class="highlight"><pre> +<span></span><span class="sd">"""Collection of assorted "maths-like" functions."""</span> + +<span class="kn">import</span> <span class="nn">torch</span> + + +<div class="viewcode-block" id="eps_like"> +<a class="viewcode-back" href="../../../api/graphnet.utilities.maths.html#graphnet.utilities.maths.eps_like">[docs]</a> +<span class="k">def</span> <span class="nf">eps_like</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span> +<span class="w"> </span><span class="sd">"""Return `eps` matching `tensor`'s dtype."""</span> + <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">finfo</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">eps</span></div> + +</pre></div> + + </article> + </div> + </div> + </main> + </div> + <footer class="md-footer"> + <div class="md-footer-nav"> + <nav class="md-footer-nav__inner md-grid"> + + + </a> + + </nav> + </div> + <div class="md-footer-meta md-typeset"> + <div class="md-footer-meta__inner md-grid"> + <div class="md-footer-copyright"> + <div class="md-footer-copyright__highlight"> + © Copyright 2021-2023, GraphNeT team. + + </div> + Created using + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. + and + <a href="https://github.com/bashtage/sphinx-material/">Material for + Sphinx</a> + </div> + </div> + </div> + </footer> + <script src="../../../_static/javascripts/application.js"></script> + <script>app.initialize({version: "1.0.4", url: {base: ".."}})</script> + </body> +</html> \ No newline at end of file diff --git a/_modules/index.html b/_modules/index.html index 89db4987e..120145b62 100644 --- a/_modules/index.html +++ b/_modules/index.html @@ -323,6 +323,11 @@ <h1 id="modules-index--page-root">All modules for which code is available</h1> <ul><li><a href="graphnet/data/constants.html">graphnet.data.constants</a></li> <li><a href="graphnet/data/dataconverter.html">graphnet.data.dataconverter</a></li> +<li><a href="graphnet/data/dataloader.html">graphnet.data.dataloader</a></li> +<li><a href="graphnet/data/dataset/dataset.html">graphnet.data.dataset.dataset</a></li> +<li><a href="graphnet/data/dataset/parquet/parquet_dataset.html">graphnet.data.dataset.parquet.parquet_dataset</a></li> +<li><a href="graphnet/data/dataset/sqlite/sqlite_dataset.html">graphnet.data.dataset.sqlite.sqlite_dataset</a></li> +<li><a href="graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html">graphnet.data.dataset.sqlite.sqlite_dataset_perturbed</a></li> <li><a href="graphnet/data/extractors/i3extractor.html">graphnet.data.extractors.i3extractor</a></li> <li><a href="graphnet/data/extractors/i3featureextractor.html">graphnet.data.extractors.i3featureextractor</a></li> <li><a href="graphnet/data/extractors/i3genericextractor.html">graphnet.data.extractors.i3genericextractor</a></li> @@ -339,18 +344,52 @@ <h1 id="modules-index--page-root">All modules for which code is available</h1> <li><a href="graphnet/data/extractors/utilities/frames.html">graphnet.data.extractors.utilities.frames</a></li> <li><a href="graphnet/data/extractors/utilities/types.html">graphnet.data.extractors.utilities.types</a></li> <li><a href="graphnet/data/parquet/parquet_dataconverter.html">graphnet.data.parquet.parquet_dataconverter</a></li> +<li><a href="graphnet/data/pipeline.html">graphnet.data.pipeline</a></li> <li><a href="graphnet/data/sqlite/sqlite_dataconverter.html">graphnet.data.sqlite.sqlite_dataconverter</a></li> <li><a href="graphnet/data/sqlite/sqlite_utilities.html">graphnet.data.sqlite.sqlite_utilities</a></li> <li><a href="graphnet/data/utilities/parquet_to_sqlite.html">graphnet.data.utilities.parquet_to_sqlite</a></li> <li><a href="graphnet/data/utilities/random.html">graphnet.data.utilities.random</a></li> <li><a href="graphnet/data/utilities/string_selection_resolver.html">graphnet.data.utilities.string_selection_resolver</a></li> +<li><a href="graphnet/deployment/i3modules/graphnet_module.html">graphnet.deployment.i3modules.graphnet_module</a></li> +<li><a href="graphnet/models/coarsening.html">graphnet.models.coarsening</a></li> +<li><a href="graphnet/models/components/layers.html">graphnet.models.components.layers</a></li> +<li><a href="graphnet/models/components/pool.html">graphnet.models.components.pool</a></li> +<li><a href="graphnet/models/detector/detector.html">graphnet.models.detector.detector</a></li> +<li><a href="graphnet/models/detector/icecube.html">graphnet.models.detector.icecube</a></li> +<li><a href="graphnet/models/detector/prometheus.html">graphnet.models.detector.prometheus</a></li> +<li><a href="graphnet/models/gnn/convnet.html">graphnet.models.gnn.convnet</a></li> +<li><a href="graphnet/models/gnn/dynedge.html">graphnet.models.gnn.dynedge</a></li> +<li><a href="graphnet/models/gnn/dynedge_jinst.html">graphnet.models.gnn.dynedge_jinst</a></li> +<li><a href="graphnet/models/gnn/dynedge_kaggle_tito.html">graphnet.models.gnn.dynedge_kaggle_tito</a></li> +<li><a href="graphnet/models/gnn/gnn.html">graphnet.models.gnn.gnn</a></li> +<li><a href="graphnet/models/graphs/edges/edges.html">graphnet.models.graphs.edges.edges</a></li> +<li><a href="graphnet/models/graphs/graph_definition.html">graphnet.models.graphs.graph_definition</a></li> +<li><a href="graphnet/models/graphs/graphs.html">graphnet.models.graphs.graphs</a></li> +<li><a href="graphnet/models/graphs/nodes/nodes.html">graphnet.models.graphs.nodes.nodes</a></li> +<li><a href="graphnet/models/model.html">graphnet.models.model</a></li> +<li><a href="graphnet/models/standard_model.html">graphnet.models.standard_model</a></li> +<li><a href="graphnet/models/task/classification.html">graphnet.models.task.classification</a></li> +<li><a href="graphnet/models/task/reconstruction.html">graphnet.models.task.reconstruction</a></li> +<li><a href="graphnet/models/task/task.html">graphnet.models.task.task</a></li> +<li><a href="graphnet/models/utils.html">graphnet.models.utils</a></li> <li><a href="graphnet/pisa/fitting.html">graphnet.pisa.fitting</a></li> <li><a href="graphnet/pisa/plotting.html">graphnet.pisa.plotting</a></li> +<li><a href="graphnet/training/callbacks.html">graphnet.training.callbacks</a></li> +<li><a href="graphnet/training/labels.html">graphnet.training.labels</a></li> +<li><a href="graphnet/training/loss_functions.html">graphnet.training.loss_functions</a></li> +<li><a href="graphnet/training/utils.html">graphnet.training.utils</a></li> <li><a href="graphnet/training/weight_fitting.html">graphnet.training.weight_fitting</a></li> <li><a href="graphnet/utilities/argparse.html">graphnet.utilities.argparse</a></li> +<li><a href="graphnet/utilities/config/base_config.html">graphnet.utilities.config.base_config</a></li> +<li><a href="graphnet/utilities/config/configurable.html">graphnet.utilities.config.configurable</a></li> +<li><a href="graphnet/utilities/config/dataset_config.html">graphnet.utilities.config.dataset_config</a></li> +<li><a href="graphnet/utilities/config/model_config.html">graphnet.utilities.config.model_config</a></li> +<li><a href="graphnet/utilities/config/parsing.html">graphnet.utilities.config.parsing</a></li> +<li><a href="graphnet/utilities/config/training_config.html">graphnet.utilities.config.training_config</a></li> <li><a href="graphnet/utilities/filesys.html">graphnet.utilities.filesys</a></li> <li><a href="graphnet/utilities/imports.html">graphnet.utilities.imports</a></li> <li><a href="graphnet/utilities/logging.html">graphnet.utilities.logging</a></li> +<li><a href="graphnet/utilities/maths.html">graphnet.utilities.maths</a></li> </ul> </article> @@ -375,7 +414,7 @@ <h1 id="modules-index--page-root">All modules for which code is available</h1> </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/about.html b/about.html index 3ebaa11ee..6422d9cf3 100644 --- a/about.html +++ b/about.html @@ -392,7 +392,7 @@ <h2 id="acknowledgements">Acknowledgements<a class="headerlink" href="#acknowled </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.constants.html b/api/graphnet.constants.html index 5cc170ec7..0a215dbf4 100644 --- a/api/graphnet.constants.html +++ b/api/graphnet.constants.html @@ -422,7 +422,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.constants.html b/api/graphnet.data.constants.html index 36edc3f62..32d33f303 100644 --- a/api/graphnet.data.constants.html +++ b/api/graphnet.data.constants.html @@ -700,7 +700,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataconverter.html b/api/graphnet.data.dataconverter.html index 61d9300a1..40c210634 100644 --- a/api/graphnet.data.dataconverter.html +++ b/api/graphnet.data.dataconverter.html @@ -820,7 +820,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataloader.html b/api/graphnet.data.dataloader.html index 8ad405a2b..98bb280f5 100644 --- a/api/graphnet.data.dataloader.html +++ b/api/graphnet.data.dataloader.html @@ -365,11 +365,54 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataloader--page-root" class="md-nav__link">dataloader</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataloader.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataloader.do_shuffle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">do_shuffle()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataloader.DataLoader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataloader.DataLoader.from_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_dataset_config()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataloader.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataloader.do_shuffle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">do_shuffle()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataloader.DataLoader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataloader.DataLoader.from_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_dataset_config()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -436,7 +479,22 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataloader--page-root" class="md-nav__link">dataloader</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataloader.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataloader.do_shuffle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">do_shuffle()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataloader.DataLoader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataloader.DataLoader.from_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_dataset_config()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -446,8 +504,73 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="dataloader"> -<h1 id="api-graphnet-data-dataloader--page-root">dataloader<a class="headerlink" href="#api-graphnet-data-dataloader--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.dataloader"> +<span id="dataloader"></span><h1 id="api-graphnet-data-dataloader--page-root">dataloader<a class="headerlink" href="#api-graphnet-data-dataloader--page-root" title="Link to this heading">¶</a></h1> +<p>Base <cite>Dataloader</cite> class(es) used in <cite>graphnet</cite>.</p> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.data.dataloader.collate_fn"> +<span class="sig-prename descclassname"><span class="pre">graphnet.data.dataloader.</span></span><span class="sig-name descname"><span class="pre">collate_fn</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">graphs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataloader.html#collate_fn"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataloader.collate_fn" title="Link to this definition">¶</a></dt> +<dd><p>Remove graphs with less than two DOM hits.</p> +<p>Should not occur in “production.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>graphs</strong> (<em>List</em><em>[</em><em>Data</em><em>]</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.data.dataloader.do_shuffle"> +<span class="sig-prename descclassname"><span class="pre">graphnet.data.dataloader.</span></span><span class="sig-name descname"><span class="pre">do_shuffle</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">selection_name</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataloader.html#do_shuffle"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataloader.do_shuffle" title="Link to this definition">¶</a></dt> +<dd><p>Check whether to shuffle selection with name <cite>selection_name</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>selection_name</strong> (<em>str</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.data.dataloader.DataLoader"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataloader.</span></span><span class="sig-name descname"><span class="pre">DataLoader</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shuffle</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">persistent_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">collate_fn=<function</span> <span class="pre">collate_fn></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prefetch_factor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">**kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataloader.html#DataLoader"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataloader.DataLoader" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code></p> +<p>Class for loading data from a <cite>Dataset</cite>.</p> +<p>Construct <cite>DataLoader</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>dataset</strong> (<em>Dataset</em><em>[</em><em>T_co</em><em>]</em>) – </p></li> +<li><p><strong>batch_size</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>shuffle</strong> (<em>bool</em>) – </p></li> +<li><p><strong>num_workers</strong> (<em>int</em>) – </p></li> +<li><p><strong>persistent_workers</strong> (<em>bool</em>) – </p></li> +<li><p><strong>collate_fn</strong> (<em>Callable</em>) – </p></li> +<li><p><strong>prefetch_factor</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.data.dataloader.DataLoader.from_dataset_config"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">from_dataset_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">config</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataloader.html#DataLoader.from_dataset_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataloader.DataLoader.from_dataset_config" title="Link to this definition">¶</a></dt> +<dd><p>Construct <cite>DataLoader`s based on selections in `DatasetConfig</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<a class="reference internal" href="#graphnet.data.dataloader.DataLoader" title="graphnet.data.dataloader.DataLoader"><code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <a class="reference internal" href="#graphnet.data.dataloader.DataLoader" title="graphnet.data.dataloader.DataLoader"><code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code></a>]]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>config</strong> (<a class="reference internal" href="graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig" title="graphnet.utilities.config.dataset_config.DatasetConfig"><em>DatasetConfig</em></a>) – </p></li> +<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -497,7 +620,7 @@ <h1 id="api-graphnet-data-dataloader--page-root">dataloader<a class="headerlink" </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataset.dataset.html b/api/graphnet.data.dataset.dataset.html index 0750d2bf5..ac7ed7ec2 100644 --- a/api/graphnet.data.dataset.dataset.html +++ b/api/graphnet.data.dataset.dataset.html @@ -336,11 +336,117 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataset-dataset--page-root" class="md-nav__link">dataset</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.ColumnMissingException" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ColumnMissingException</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.load_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_module()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.parse_graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">parse_graph_definition()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Dataset</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.concatenate" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">concatenate()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.add_label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">add_label()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.EnsembleDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.ColumnMissingException" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ColumnMissingException</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.load_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_module()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.parse_graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">parse_graph_definition()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.Dataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Dataset</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.Dataset.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.Dataset.concatenate" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">concatenate()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.Dataset.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.Dataset.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.Dataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.Dataset.add_label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">add_label()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.dataset.EnsembleDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a> + + + </li></ul> + </li></ul> </li> @@ -458,7 +564,36 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataset-dataset--page-root" class="md-nav__link">dataset</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.ColumnMissingException" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ColumnMissingException</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.load_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_module()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.parse_graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">parse_graph_definition()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Dataset</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.concatenate" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">concatenate()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.Dataset.add_label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">add_label()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.data.dataset.dataset.EnsembleDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -468,8 +603,197 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="dataset"> -<h1 id="api-graphnet-data-dataset-dataset--page-root">dataset<a class="headerlink" href="#api-graphnet-data-dataset-dataset--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.dataset.dataset"> +<span id="dataset"></span><h1 id="api-graphnet-data-dataset-dataset--page-root">dataset<a class="headerlink" href="#api-graphnet-data-dataset-dataset--page-root" title="Link to this heading">¶</a></h1> +<p>Base <a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a> class(es) used in GraphNeT.</p> +<dl class="py exception"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.ColumnMissingException"> +<em class="property"><span class="pre">exception</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">ColumnMissingException</span></span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#ColumnMissingException"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.ColumnMissingException" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Exception</span></code></p> +<p>Exception to indicate a missing column in a dataset.</p> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.load_module"> +<span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">load_module</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">class_name</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#load_module"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.load_module" title="Link to this definition">¶</a></dt> +<dd><p>Load graphnet module from string name.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>class_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – name of class</p> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Type</span></code></p> +</dd> +<dt class="field-odd">Returns<span class="colon">:</span></dt> +<dd class="field-odd"><p>graphnet module.</p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.parse_graph_definition"> +<span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">parse_graph_definition</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cfg</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#parse_graph_definition"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.parse_graph_definition" title="Link to this definition">¶</a></dt> +<dd><p>Construct GraphDefinition from DatasetConfig.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>cfg</strong> (<em>dict</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">Dataset</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger" title="graphnet.utilities.logging.Logger"><code class="xref py py-class docutils literal notranslate"><span class="pre">Logger</span></code></a>, <a class="reference internal" href="graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable" title="graphnet.utilities.config.configurable.Configurable"><code class="xref py py-class docutils literal notranslate"><span class="pre">Configurable</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code></p> +<p>Base Dataset class for reading from any intermediate file format.</p> +<p>Construct Dataset.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Path to the file(s) from which this <cite>Dataset</cite> should read.</p></li> +<li><p><strong>pulsemaps</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Name(s) of the pulse map series that should be used to +construct the nodes on the individual graph objects, and their +features. Multiple pulse series maps can be used, e.g., when +different DOM types are stored in different maps.</p></li> +<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of columns in the input files that should be used as +node features on the graph objects.</p></li> +<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of event-level columns in the input files that should +be used added as attributes on the graph objects.</p></li> +<li><p><strong>node_truth</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of node-level columns in the input files that +should be used added as attributes on the graph objects.</p></li> +<li><p><strong>index_column</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'event_no'</span></code>) – Name of the column in the input files that contains +unique indicies to identify and map events across tables.</p></li> +<li><p><strong>truth_table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'truth'</span></code>) – Name of the table containing event-level truth +information.</p></li> +<li><p><strong>node_truth_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing node-level truth +information.</p></li> +<li><p><strong>string_selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of strings for which data should be read +and used to construct graph objects. Defaults to None, meaning +all strings for which data exists are used.</p></li> +<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The events that should be read. This can be given either +as list of indicies (in <cite>index_column</cite>); or a string-based +selection used to query the <cite>Dataset</cite> for events passing the +selection. Defaults to None, meaning that all events in the +input files are read.</p></li> +<li><p><strong>dtype</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>, default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – Type of the feature tensor on the graph objects returned.</p></li> +<li><p><strong>loss_weight_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing per-event loss +weights.</p></li> +<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the column in <cite>loss_weight_table</cite> +containing per-event loss weights. This is also the name of the +corresponding attribute assigned to the graph object.</p></li> +<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Default per-event loss weight. +NOTE: This default value is only applied when +<cite>loss_weight_table</cite> and <cite>loss_weight_column</cite> are specified, and +in this case to events with no value in the corresponding +table/column. That is, if no per-event loss weight table/column +is provided, this value is ignored. Defaults to None.</p></li> +<li><p><strong>seed</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Random number generator seed, used for selecting a random +subset of events when resolving a string-based selection (e.g., +<cite>“10000 random events ~ event_no % 5 > 0”</cite> or <cite>“20% random +events ~ event_no % 5 > 0”</cite>).</p></li> +<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>) – Method that defines the graph representation.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.from_config"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">from_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">source</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset.from_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.from_config" title="Link to this definition">¶</a></dt> +<dd><p>Construct <cite>Dataset</cite> instance from <cite>source</cite> configuration.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a>, <a class="reference internal" href="#graphnet.data.dataset.dataset.EnsembleDataset" title="graphnet.data.dataset.dataset.EnsembleDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a>], <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <a class="reference internal" href="#graphnet.data.dataset.dataset.EnsembleDataset" title="graphnet.data.dataset.dataset.EnsembleDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a>]]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>source</strong> (<a class="reference internal" href="graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig" title="graphnet.utilities.config.dataset_config.DatasetConfig"><em>DatasetConfig</em></a><em> | </em><em>str</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.concatenate"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">concatenate</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">datasets</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset.concatenate"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.concatenate" title="Link to this definition">¶</a></dt> +<dd><p>Concatenate multiple <a href="#id1"><span class="problematic" id="id2">`</span></a>Dataset`s into one instance.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><a class="reference internal" href="#graphnet.data.dataset.dataset.EnsembleDataset" title="graphnet.data.dataset.dataset.EnsembleDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>datasets</strong> (<em>List</em><em>[</em><a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><em>Dataset</em></a><em>]</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.path"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">path</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">str</span><span class="w"> </span><span class="p"><span class="pre">|</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.path" title="Link to this definition">¶</a></dt> +<dd><p>Path to the file(s) from which this <cite>Dataset</cite> reads.</p> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.truth_table"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">truth_table</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">str</span></em><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.truth_table" title="Link to this definition">¶</a></dt> +<dd><p>Name of the table containing event-level truth information.</p> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.query_table"> +<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">query_table</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sequential_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset.query_table"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.query_table" title="Link to this definition">¶</a></dt> +<dd><p>Query a table at a specific index, optionally with some selection.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – Table to be queried.</p></li> +<li><p><strong>columns</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – Columns to read out.</p></li> +<li><p><strong>sequential_index</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Sequentially numbered index +(i.e. in [0,len(self))) of the event to query. This _may_ +differ from the indexation used in <cite>self._indices</cite>. If no value +is provided, the entire column is returned.</p></li> +<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Selection to be imposed before reading out data. +Defaults to None.</p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]</p> +</dd> +<dt class="field-odd">Returns<span class="colon">:</span></dt> +<dd class="field-odd"><p><dl class="simple"> +<dt>List of tuples containing the values in <cite>columns</cite>. If the <cite>table</cite></dt><dd><p>contains only scalar data for <cite>columns</cite>, a list of length 1 is +returned</p> +</dd> +</dl> +</p> +</dd> +<dt class="field-even">Raises<span class="colon">:</span></dt> +<dd class="field-even"><p><a class="reference internal" href="#graphnet.data.dataset.dataset.ColumnMissingException" title="graphnet.data.dataset.dataset.ColumnMissingException"><strong>ColumnMissingException</strong></a> – If one or more element in <cite>columns</cite> is not + present in <cite>table</cite>.</p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.Dataset.add_label"> +<span class="sig-name descname"><span class="pre">add_label</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">key</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#Dataset.add_label"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.Dataset.add_label" title="Link to this definition">¶</a></dt> +<dd><p>Add custom graph label define using function <cite>fn</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>fn</strong> (<em>Callable</em><em>[</em><em>[</em><em>Data</em><em>]</em><em>, </em><em>Any</em><em>]</em>) – </p></li> +<li><p><strong>key</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.data.dataset.dataset.EnsembleDataset"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.dataset.</span></span><span class="sig-name descname"><span class="pre">EnsembleDataset</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">datasets</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/dataset.html#EnsembleDataset"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.dataset.EnsembleDataset" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">ConcatDataset</span></code></p> +<p>Construct a single dataset from a collection of datasets.</p> +<p>Construct a single dataset from a collection of datasets.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>datasets</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Iterable</span></code>[<a class="reference internal" href="#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a>]) – A collection of Datasets</p> +</dd> +</dl> +</dd></dl> </section> @@ -519,7 +843,7 @@ <h1 id="api-graphnet-data-dataset-dataset--page-root">dataset<a class="headerlin </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataset.html b/api/graphnet.data.dataset.html index 9c721d294..a4b095f8e 100644 --- a/api/graphnet.data.dataset.html +++ b/api/graphnet.data.dataset.html @@ -467,8 +467,9 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="dataset"> -<h1 id="api-graphnet-data-dataset--page-root">dataset<a class="headerlink" href="#api-graphnet-data-dataset--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.dataset"> +<span id="dataset"></span><h1 id="api-graphnet-data-dataset--page-root">dataset<a class="headerlink" href="#api-graphnet-data-dataset--page-root" title="Link to this heading">¶</a></h1> +<p>Dataset classes for training in GraphNeT.</p> <p><h2> Subpackages </h2></p> <div class="toctree-wrapper compound"> <ul> @@ -486,7 +487,14 @@ <h1 id="api-graphnet-data-dataset--page-root">dataset<a class="headerlink" href= <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.dataset.html">dataset</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.dataset.html">dataset</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.ColumnMissingException"><code class="docutils literal notranslate"><span class="pre">ColumnMissingException</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.load_module"><code class="docutils literal notranslate"><span class="pre">load_module()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.parse_graph_definition"><code class="docutils literal notranslate"><span class="pre">parse_graph_definition()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset"><code class="docutils literal notranslate"><span class="pre">Dataset</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.EnsembleDataset"><code class="docutils literal notranslate"><span class="pre">EnsembleDataset</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -538,7 +546,7 @@ <h1 id="api-graphnet-data-dataset--page-root">dataset<a class="headerlink" href= </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataset.parquet.html b/api/graphnet.data.dataset.parquet.html index 873efa2c0..492684fad 100644 --- a/api/graphnet.data.dataset.parquet.html +++ b/api/graphnet.data.dataset.parquet.html @@ -475,12 +475,16 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="parquet"> -<h1 id="api-graphnet-data-dataset-parquet--page-root">parquet<a class="headerlink" href="#api-graphnet-data-dataset-parquet--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.dataset.parquet"> +<span id="parquet"></span><h1 id="api-graphnet-data-dataset-parquet--page-root">parquet<a class="headerlink" href="#api-graphnet-data-dataset-parquet--page-root" title="Link to this heading">¶</a></h1> +<p>Datasets using parquet backend.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.parquet.parquet_dataset.html">parquet_dataset</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.parquet.parquet_dataset.html">parquet_dataset</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset"><code class="docutils literal notranslate"><span class="pre">ParquetDataset</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -532,7 +536,7 @@ <h1 id="api-graphnet-data-dataset-parquet--page-root">parquet<a class="headerlin </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataset.parquet.parquet_dataset.html b/api/graphnet.data.dataset.parquet.parquet_dataset.html index 67cfc6f08..3598486cf 100644 --- a/api/graphnet.data.dataset.parquet.parquet_dataset.html +++ b/api/graphnet.data.dataset.parquet.parquet_dataset.html @@ -328,11 +328,36 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataset-parquet-parquet-dataset--page-root" class="md-nav__link">parquet_dataset</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ParquetDataset</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ParquetDataset</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li></ul> + + </li></ul> + </li></ul> </li> @@ -466,7 +491,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataset-parquet-parquet-dataset--page-root" class="md-nav__link">parquet_dataset</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ParquetDataset</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -476,8 +512,82 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="parquet-dataset"> -<h1 id="api-graphnet-data-dataset-parquet-parquet-dataset--page-root">parquet_dataset<a class="headerlink" href="#api-graphnet-data-dataset-parquet-parquet-dataset--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.dataset.parquet.parquet_dataset"> +<span id="parquet-dataset"></span><h1 id="api-graphnet-data-dataset-parquet-parquet-dataset--page-root">parquet_dataset<a class="headerlink" href="#api-graphnet-data-dataset-parquet-parquet-dataset--page-root" title="Link to this heading">¶</a></h1> +<p><cite>Dataset</cite> class(es) for reading from Parquet files.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.parquet.parquet_dataset.</span></span><span class="sig-name descname"><span class="pre">ParquetDataset</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/parquet/parquet_dataset.html#ParquetDataset"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a></p> +<p>Pytorch dataset for reading from Parquet files.</p> +<p>Construct Dataset.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Path to the file(s) from which this <cite>Dataset</cite> should read.</p></li> +<li><p><strong>pulsemaps</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Name(s) of the pulse map series that should be used to +construct the nodes on the individual graph objects, and their +features. Multiple pulse series maps can be used, e.g., when +different DOM types are stored in different maps.</p></li> +<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of columns in the input files that should be used as +node features on the graph objects.</p></li> +<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of event-level columns in the input files that should +be used added as attributes on the graph objects.</p></li> +<li><p><strong>node_truth</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of node-level columns in the input files that +should be used added as attributes on the graph objects.</p></li> +<li><p><strong>index_column</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'event_no'</span></code>) – Name of the column in the input files that contains +unique indicies to identify and map events across tables.</p></li> +<li><p><strong>truth_table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'truth'</span></code>) – Name of the table containing event-level truth +information.</p></li> +<li><p><strong>node_truth_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing node-level truth +information.</p></li> +<li><p><strong>string_selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of strings for which data should be read +and used to construct graph objects. Defaults to None, meaning +all strings for which data exists are used.</p></li> +<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The events that should be read. This can be given either +as list of indicies (in <cite>index_column</cite>); or a string-based +selection used to query the <cite>Dataset</cite> for events passing the +selection. Defaults to None, meaning that all events in the +input files are read.</p></li> +<li><p><strong>dtype</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>, default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – Type of the feature tensor on the graph objects returned.</p></li> +<li><p><strong>loss_weight_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing per-event loss +weights.</p></li> +<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the column in <cite>loss_weight_table</cite> +containing per-event loss weights. This is also the name of the +corresponding attribute assigned to the graph object.</p></li> +<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Default per-event loss weight. +NOTE: This default value is only applied when +<cite>loss_weight_table</cite> and <cite>loss_weight_column</cite> are specified, and +in this case to events with no value in the corresponding +table/column. That is, if no per-event loss weight table/column +is provided, this value is ignored. Defaults to None.</p></li> +<li><p><strong>seed</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Random number generator seed, used for selecting a random +subset of events when resolving a string-based selection (e.g., +<cite>“10000 random events ~ event_no % 5 > 0”</cite> or <cite>“20% random +events ~ event_no % 5 > 0”</cite>).</p></li> +<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>) – Method that defines the graph representation.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table"> +<span class="sig-name descname"><span class="pre">query_table</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sequential_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/parquet/parquet_dataset.html#ParquetDataset.query_table"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table" title="Link to this definition">¶</a></dt> +<dd><p>Query table at a specific index, optionally with some selection.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>table</strong> (<em>str</em>) – </p></li> +<li><p><strong>columns</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>str</em>) – </p></li> +<li><p><strong>sequential_index</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>selection</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -527,7 +637,7 @@ <h1 id="api-graphnet-data-dataset-parquet-parquet-dataset--page-root">parquet_da </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataset.sqlite.html b/api/graphnet.data.dataset.sqlite.html index 9cfc9a2a5..8ea0f5783 100644 --- a/api/graphnet.data.dataset.sqlite.html +++ b/api/graphnet.data.dataset.sqlite.html @@ -482,13 +482,20 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="sqlite"> -<h1 id="api-graphnet-data-dataset-sqlite--page-root">sqlite<a class="headerlink" href="#api-graphnet-data-dataset-sqlite--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.dataset.sqlite"> +<span id="sqlite"></span><h1 id="api-graphnet-data-dataset-sqlite--page-root">sqlite<a class="headerlink" href="#api-graphnet-data-dataset-sqlite--page-root" title="Link to this heading">¶</a></h1> +<p>Datasets using SQLite backend.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset.html">sqlite_dataset</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html">sqlite_dataset_perturbed</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset.html">sqlite_dataset</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset"><code class="docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html">sqlite_dataset_perturbed</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed"><code class="docutils literal notranslate"><span class="pre">SQLiteDatasetPerturbed</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -540,7 +547,7 @@ <h1 id="api-graphnet-data-dataset-sqlite--page-root">sqlite<a class="headerlink" </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataset.sqlite.sqlite_dataset.html b/api/graphnet.data.dataset.sqlite.sqlite_dataset.html index 36b0988a6..7d25db896 100644 --- a/api/graphnet.data.dataset.sqlite.sqlite_dataset.html +++ b/api/graphnet.data.dataset.sqlite.sqlite_dataset.html @@ -335,11 +335,36 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root" class="md-nav__link">sqlite_dataset</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -473,7 +498,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root" class="md-nav__link">sqlite_dataset</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">query_table()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -483,8 +519,82 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="sqlite-dataset"> -<h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root">sqlite_dataset<a class="headerlink" href="#api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.dataset.sqlite.sqlite_dataset"> +<span id="sqlite-dataset"></span><h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root">sqlite_dataset<a class="headerlink" href="#api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root" title="Link to this heading">¶</a></h1> +<p><cite>Dataset</cite> class(es) for reading data from SQLite databases.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.sqlite.sqlite_dataset.</span></span><span class="sig-name descname"><span class="pre">SQLiteDataset</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html#SQLiteDataset"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a></p> +<p>Pytorch dataset for reading data from SQLite databases.</p> +<p>Construct Dataset.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Path to the file(s) from which this <cite>Dataset</cite> should read.</p></li> +<li><p><strong>pulsemaps</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Name(s) of the pulse map series that should be used to +construct the nodes on the individual graph objects, and their +features. Multiple pulse series maps can be used, e.g., when +different DOM types are stored in different maps.</p></li> +<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of columns in the input files that should be used as +node features on the graph objects.</p></li> +<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of event-level columns in the input files that should +be used added as attributes on the graph objects.</p></li> +<li><p><strong>node_truth</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of node-level columns in the input files that +should be used added as attributes on the graph objects.</p></li> +<li><p><strong>index_column</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'event_no'</span></code>) – Name of the column in the input files that contains +unique indicies to identify and map events across tables.</p></li> +<li><p><strong>truth_table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'truth'</span></code>) – Name of the table containing event-level truth +information.</p></li> +<li><p><strong>node_truth_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing node-level truth +information.</p></li> +<li><p><strong>string_selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of strings for which data should be read +and used to construct graph objects. Defaults to None, meaning +all strings for which data exists are used.</p></li> +<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The events that should be read. This can be given either +as list of indicies (in <cite>index_column</cite>); or a string-based +selection used to query the <cite>Dataset</cite> for events passing the +selection. Defaults to None, meaning that all events in the +input files are read.</p></li> +<li><p><strong>dtype</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>, default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – Type of the feature tensor on the graph objects returned.</p></li> +<li><p><strong>loss_weight_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing per-event loss +weights.</p></li> +<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the column in <cite>loss_weight_table</cite> +containing per-event loss weights. This is also the name of the +corresponding attribute assigned to the graph object.</p></li> +<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Default per-event loss weight. +NOTE: This default value is only applied when +<cite>loss_weight_table</cite> and <cite>loss_weight_column</cite> are specified, and +in this case to events with no value in the corresponding +table/column. That is, if no per-event loss weight table/column +is provided, this value is ignored. Defaults to None.</p></li> +<li><p><strong>seed</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Random number generator seed, used for selecting a random +subset of events when resolving a string-based selection (e.g., +<cite>“10000 random events ~ event_no % 5 > 0”</cite> or <cite>“20% random +events ~ event_no % 5 > 0”</cite>).</p></li> +<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>) – Method that defines the graph representation.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table"> +<span class="sig-name descname"><span class="pre">query_table</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sequential_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html#SQLiteDataset.query_table"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table" title="Link to this definition">¶</a></dt> +<dd><p>Query table at a specific index, optionally with some selection.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>table</strong> (<em>str</em>) – </p></li> +<li><p><strong>columns</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>str</em>) – </p></li> +<li><p><strong>sequential_index</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>selection</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -534,7 +644,7 @@ <h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset--page-root">sqlite_datas </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html b/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html index e5b66ee43..f46386964 100644 --- a/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html +++ b/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html @@ -342,11 +342,25 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root" class="md-nav__link">sqlite_dataset_perturbed</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDatasetPerturbed</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDatasetPerturbed</span></code></a> + </li></ul> + </li></ul> </li> @@ -473,7 +487,14 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root" class="md-nav__link">sqlite_dataset_perturbed</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">SQLiteDatasetPerturbed</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -483,8 +504,65 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="sqlite-dataset-perturbed"> -<h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root">sqlite_dataset_perturbed<a class="headerlink" href="#api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"> +<span id="sqlite-dataset-perturbed"></span><h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root">sqlite_dataset_perturbed<a class="headerlink" href="#api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root" title="Link to this heading">¶</a></h1> +<p><cite>Dataset</cite> class(es) for reading perturbed data from SQLite databases.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.</span></span><span class="sig-name descname"><span class="pre">SQLiteDatasetPerturbed</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">perturbation_dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html#SQLiteDatasetPerturbed"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset" title="graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">SQLiteDataset</span></code></a></p> +<p>Pytorch dataset for reading perturbed data from SQLite databases.</p> +<p>This including a pre-processing step, where the input data is randomly +perturbed according to given per-feature “noise” levels. This is intended +to test the stability of a trained model under small changes to the input +parameters.</p> +<p>Construct SQLiteDatasetPerturbed.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Path to the file(s) from which this <cite>Dataset</cite> should read.</p></li> +<li><p><strong>pulsemaps</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]]) – Name(s) of the pulse map series that should be used to +construct the nodes on the individual graph objects, and their +features. Multiple pulse series maps can be used, e.g., when +different DOM types are stored in different maps.</p></li> +<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of columns in the input files that should be used as +node features on the graph objects.</p></li> +<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of event-level columns in the input files that should +be used added as attributes on the graph objects.</p></li> +<li><p><strong>perturbation_dict</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>float</em><em>]</em>) – Dictionary mapping a feature +name to a standard deviation according to which the values for +this feature should be randomly perturbed.</p></li> +<li><p><strong>node_truth</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of node-level columns in the input files that +should be used added as attributes on the graph objects.</p></li> +<li><p><strong>index_column</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'event_no'</span></code>) – Name of the column in the input files that contains +unique indicies to identify and map events across tables.</p></li> +<li><p><strong>truth_table</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'truth'</span></code>) – Name of the table containing event-level truth +information.</p></li> +<li><p><strong>node_truth_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing node-level truth +information.</p></li> +<li><p><strong>string_selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of strings for which data should be read +and used to construct graph objects. Defaults to None, meaning +all strings for which data exists are used.</p></li> +<li><p><strong>selection</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of indicies (in <cite>index_column</cite>) of the events in +the input files that should be read. Defaults to None, meaning +that all events in the input files are read.</p></li> +<li><p><strong>dtype</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>, default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – Type of the feature tensor on the graph objects returned.</p></li> +<li><p><strong>loss_weight_table</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the table containing per-event loss +weights.</p></li> +<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the column in <cite>loss_weight_table</cite> +containing per-event loss weights. This is also the name of the +corresponding attribute assigned to the graph object.</p></li> +<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Default per-event loss weight. +NOTE: This default value is only applied when +<cite>loss_weight_table</cite> and <cite>loss_weight_column</cite> are specified, and +in this case to events with no value in the corresponding +table/column. That is, if no per-event loss weight table/column +is provided, this value is ignored. Defaults to None.</p></li> +<li><p><strong>seed</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Generator</span></code>, <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional seed for random number generation. Defaults to None.</p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -534,7 +612,7 @@ <h1 id="api-graphnet-data-dataset-sqlite-sqlite-dataset-perturbed--page-root">sq </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.html b/api/graphnet.data.extractors.html index df9ce380e..5810ea8fc 100644 --- a/api/graphnet.data.extractors.html +++ b/api/graphnet.data.extractors.html @@ -658,7 +658,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3extractor.html b/api/graphnet.data.extractors.i3extractor.html index 884ec4690..4671c09f7 100644 --- a/api/graphnet.data.extractors.i3extractor.html +++ b/api/graphnet.data.extractors.i3extractor.html @@ -732,7 +732,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3featureextractor.html b/api/graphnet.data.extractors.i3featureextractor.html index fd27d93af..520605dfe 100644 --- a/api/graphnet.data.extractors.i3featureextractor.html +++ b/api/graphnet.data.extractors.i3featureextractor.html @@ -720,7 +720,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3genericextractor.html b/api/graphnet.data.extractors.i3genericextractor.html index fffb0ae7a..2915de7ce 100644 --- a/api/graphnet.data.extractors.i3genericextractor.html +++ b/api/graphnet.data.extractors.i3genericextractor.html @@ -639,7 +639,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3hybridrecoextractor.html b/api/graphnet.data.extractors.i3hybridrecoextractor.html index c13a47727..e23c93bd9 100644 --- a/api/graphnet.data.extractors.i3hybridrecoextractor.html +++ b/api/graphnet.data.extractors.i3hybridrecoextractor.html @@ -623,7 +623,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html b/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html index d87d2a70a..f9a853695 100644 --- a/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html +++ b/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html @@ -626,7 +626,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3particleextractor.html b/api/graphnet.data.extractors.i3particleextractor.html index ed80239fc..fbdb86f6e 100644 --- a/api/graphnet.data.extractors.i3particleextractor.html +++ b/api/graphnet.data.extractors.i3particleextractor.html @@ -625,7 +625,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3pisaextractor.html b/api/graphnet.data.extractors.i3pisaextractor.html index 82a7edbf0..4471bf09a 100644 --- a/api/graphnet.data.extractors.i3pisaextractor.html +++ b/api/graphnet.data.extractors.i3pisaextractor.html @@ -623,7 +623,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3quesoextractor.html b/api/graphnet.data.extractors.i3quesoextractor.html index 155c83c34..649bb0171 100644 --- a/api/graphnet.data.extractors.i3quesoextractor.html +++ b/api/graphnet.data.extractors.i3quesoextractor.html @@ -626,7 +626,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3retroextractor.html b/api/graphnet.data.extractors.i3retroextractor.html index 579e01cb4..86f076405 100644 --- a/api/graphnet.data.extractors.i3retroextractor.html +++ b/api/graphnet.data.extractors.i3retroextractor.html @@ -623,7 +623,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3splinempeextractor.html b/api/graphnet.data.extractors.i3splinempeextractor.html index fd7c55164..f229b0090 100644 --- a/api/graphnet.data.extractors.i3splinempeextractor.html +++ b/api/graphnet.data.extractors.i3splinempeextractor.html @@ -623,7 +623,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3truthextractor.html b/api/graphnet.data.extractors.i3truthextractor.html index 8e211b71a..213393e35 100644 --- a/api/graphnet.data.extractors.i3truthextractor.html +++ b/api/graphnet.data.extractors.i3truthextractor.html @@ -630,7 +630,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.i3tumextractor.html b/api/graphnet.data.extractors.i3tumextractor.html index 4b4ad2d29..326699de0 100644 --- a/api/graphnet.data.extractors.i3tumextractor.html +++ b/api/graphnet.data.extractors.i3tumextractor.html @@ -623,7 +623,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.utilities.collections.html b/api/graphnet.data.extractors.utilities.collections.html index d963c742e..0970c915a 100644 --- a/api/graphnet.data.extractors.utilities.collections.html +++ b/api/graphnet.data.extractors.utilities.collections.html @@ -708,7 +708,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.utilities.frames.html b/api/graphnet.data.extractors.utilities.frames.html index f895882b1..2fcdda991 100644 --- a/api/graphnet.data.extractors.utilities.frames.html +++ b/api/graphnet.data.extractors.utilities.frames.html @@ -709,7 +709,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.utilities.html b/api/graphnet.data.extractors.utilities.html index da9334938..d7361ced0 100644 --- a/api/graphnet.data.extractors.utilities.html +++ b/api/graphnet.data.extractors.utilities.html @@ -640,7 +640,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.extractors.utilities.types.html b/api/graphnet.data.extractors.utilities.types.html index 6c8154f02..10266dd2d 100644 --- a/api/graphnet.data.extractors.utilities.types.html +++ b/api/graphnet.data.extractors.utilities.types.html @@ -867,7 +867,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.html b/api/graphnet.data.html index dff3daeae..57b5f245a 100644 --- a/api/graphnet.data.html +++ b/api/graphnet.data.html @@ -507,8 +507,16 @@ <li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter"><code class="docutils literal notranslate"><span class="pre">DataConverter</span></code></a></li> </ul> </li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataloader.html">dataloader</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.data.pipeline.html">pipeline</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.data.dataloader.html">dataloader</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataloader.html#graphnet.data.dataloader.collate_fn"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataloader.html#graphnet.data.dataloader.do_shuffle"><code class="docutils literal notranslate"><span class="pre">do_shuffle()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader"><code class="docutils literal notranslate"><span class="pre">DataLoader</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.data.pipeline.html">pipeline</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.data.pipeline.html#graphnet.data.pipeline.InSQLitePipeline"><code class="docutils literal notranslate"><span class="pre">InSQLitePipeline</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -560,7 +568,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.parquet.html b/api/graphnet.data.parquet.html index 31148efc3..ebb6a3f5f 100644 --- a/api/graphnet.data.parquet.html +++ b/api/graphnet.data.parquet.html @@ -514,7 +514,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.parquet.parquet_dataconverter.html b/api/graphnet.data.parquet.parquet_dataconverter.html index 3b17b9603..a0f810016 100644 --- a/api/graphnet.data.parquet.parquet_dataconverter.html +++ b/api/graphnet.data.parquet.parquet_dataconverter.html @@ -665,7 +665,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.pipeline.html b/api/graphnet.data.pipeline.html index 633c55e35..3e14a0e83 100644 --- a/api/graphnet.data.pipeline.html +++ b/api/graphnet.data.pipeline.html @@ -372,11 +372,25 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-pipeline--page-root" class="md-nav__link">pipeline</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.pipeline.InSQLitePipeline" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InSQLitePipeline</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.data.pipeline.InSQLitePipeline" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InSQLitePipeline</span></code></a> + </li></ul> + </li></ul> </li> @@ -436,7 +450,14 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-data-pipeline--page-root" class="md-nav__link">pipeline</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.data.pipeline.InSQLitePipeline" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InSQLitePipeline</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -446,8 +467,36 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="pipeline"> -<h1 id="api-graphnet-data-pipeline--page-root">pipeline<a class="headerlink" href="#api-graphnet-data-pipeline--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.data.pipeline"> +<span id="pipeline"></span><h1 id="api-graphnet-data-pipeline--page-root">pipeline<a class="headerlink" href="#api-graphnet-data-pipeline--page-root" title="Link to this heading">¶</a></h1> +<p>Class(es) used for analysis in PISA.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.data.pipeline.InSQLitePipeline"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.data.pipeline.</span></span><span class="sig-name descname"><span class="pre">InSQLitePipeline</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">module_dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">retro_table_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">outdir</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pipeline_name</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/data/pipeline.html#InSQLitePipeline"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.data.pipeline.InSQLitePipeline" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code>, <a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger" title="graphnet.utilities.logging.Logger"><code class="xref py py-class docutils literal notranslate"><span class="pre">Logger</span></code></a></p> +<p>Create a SQLite database for PISA analysis.</p> +<p>The database will contain truth and GNN predictions and, if available, +RETRO reconstructions.</p> +<p>Initialise the pipeline.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>module_dict</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>) – A dictionary with GNN modules from GraphNet. E.g. +{‘energy’: gnn_module_for_energy_regression}</p></li> +<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of input features for the GNN modules.</p></li> +<li><p><strong>truth</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – List of truth for the GNN ModuleList.</p></li> +<li><p><strong>device</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">device</span></code>) – The device used for computation.</p></li> +<li><p><strong>retro_table_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'retro'</span></code>) – Name of the retro table for.</p></li> +<li><p><strong>outdir</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – the directory in which the pipeline database will be +stored.</p></li> +<li><p><strong>batch_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">100</span></code>) – Batch size for inference.</p></li> +<li><p><strong>n_workers</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">10</span></code>) – Number of workers used in dataloading.</p></li> +<li><p><strong>pipeline_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'pipeline'</span></code>) – Name of the pipeline. If such a pipeline already +exists, an error will be prompted to avoid overwriting.</p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -497,7 +546,7 @@ <h1 id="api-graphnet-data-pipeline--page-root">pipeline<a class="headerlink" hre </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.sqlite.html b/api/graphnet.data.sqlite.html index 16343c670..cd932166c 100644 --- a/api/graphnet.data.sqlite.html +++ b/api/graphnet.data.sqlite.html @@ -534,7 +534,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.sqlite.sqlite_dataconverter.html b/api/graphnet.data.sqlite.sqlite_dataconverter.html index 302c9e50d..dfbb879df 100644 --- a/api/graphnet.data.sqlite.sqlite_dataconverter.html +++ b/api/graphnet.data.sqlite.sqlite_dataconverter.html @@ -776,7 +776,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.sqlite.sqlite_utilities.html b/api/graphnet.data.sqlite.sqlite_utilities.html index b11e73e50..514fbdacf 100644 --- a/api/graphnet.data.sqlite.sqlite_utilities.html +++ b/api/graphnet.data.sqlite.sqlite_utilities.html @@ -727,7 +727,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.utilities.html b/api/graphnet.data.utilities.html index 35675598e..894827708 100644 --- a/api/graphnet.data.utilities.html +++ b/api/graphnet.data.utilities.html @@ -536,7 +536,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.utilities.parquet_to_sqlite.html b/api/graphnet.data.utilities.parquet_to_sqlite.html index eca0271ac..86d85b9eb 100644 --- a/api/graphnet.data.utilities.parquet_to_sqlite.html +++ b/api/graphnet.data.utilities.parquet_to_sqlite.html @@ -591,7 +591,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.utilities.random.html b/api/graphnet.data.utilities.random.html index e3b1572b6..ee45ec2a4 100644 --- a/api/graphnet.data.utilities.random.html +++ b/api/graphnet.data.utilities.random.html @@ -563,7 +563,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.data.utilities.string_selection_resolver.html b/api/graphnet.data.utilities.string_selection_resolver.html index 2ee151d57..2be85aed2 100644 --- a/api/graphnet.data.utilities.string_selection_resolver.html +++ b/api/graphnet.data.utilities.string_selection_resolver.html @@ -552,7 +552,7 @@ <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><ul class="simple"> -<li><p><strong>dataset</strong> (<em>Dataset</em>) – </p></li> +<li><p><strong>dataset</strong> (<a class="reference internal" href="graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset" title="graphnet.data.dataset.dataset.Dataset"><em>Dataset</em></a>) – </p></li> <li><p><strong>index_column</strong> (<em>str</em>) – </p></li> <li><p><strong>seed</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li> <li><p><strong>use_cache</strong> (<em>bool</em>) – </p></li> @@ -626,7 +626,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.deployment.html b/api/graphnet.deployment.html index 534adda23..a21590d71 100644 --- a/api/graphnet.deployment.html +++ b/api/graphnet.deployment.html @@ -453,7 +453,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.deployment.i3modules.deployer.html b/api/graphnet.deployment.i3modules.deployer.html index 5a1c55118..c29615e41 100644 --- a/api/graphnet.deployment.i3modules.deployer.html +++ b/api/graphnet.deployment.i3modules.deployer.html @@ -456,7 +456,7 @@ <h1 id="api-graphnet-deployment-i3modules-deployer--page-root">deployer<a class= </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.deployment.i3modules.graphnet_module.html b/api/graphnet.deployment.i3modules.graphnet_module.html index 0005fa3ad..ac93cd1ea 100644 --- a/api/graphnet.deployment.i3modules.graphnet_module.html +++ b/api/graphnet.deployment.i3modules.graphnet_module.html @@ -336,11 +336,43 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-deployment-i3modules-graphnet-module--page-root" class="md-nav__link">graphnet_module</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3PulseCleanerModule</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3PulseCleanerModule</span></code></a> + + + </li></ul> + </li></ul> </li></ul> @@ -395,7 +427,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-deployment-i3modules-graphnet-module--page-root" class="md-nav__link">graphnet_module</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">I3PulseCleanerModule</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -405,8 +448,94 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="graphnet-module"> -<h1 id="api-graphnet-deployment-i3modules-graphnet-module--page-root">graphnet_module<a class="headerlink" href="#api-graphnet-deployment-i3modules-graphnet-module--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.deployment.i3modules.graphnet_module"> +<span id="graphnet-module"></span><h1 id="api-graphnet-deployment-i3modules-graphnet-module--page-root">graphnet_module<a class="headerlink" href="#api-graphnet-deployment-i3modules-graphnet-module--page-root" title="Link to this heading">¶</a></h1> +<p>Class(es) for deploying GraphNeT models in icetray as I3Modules.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.deployment.i3modules.graphnet_module.</span></span><span class="sig-name descname"><span class="pre">GraphNeTI3Module</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemap</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemap_extractor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gcd_file</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/deployment/i3modules/graphnet_module.html#GraphNeTI3Module"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">object</span></code></p> +<p>Base I3 Module for GraphNeT.</p> +<p>Contains methods for extracting pulsemaps, producing graphs and writing to +frames.</p> +<p>I3Module Constructor.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a>) – An instance of GraphDefinition. E.g. KNNGraph.</p></li> +<li><p><strong>pulsemap</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – the pulse map on which the module functions</p></li> +<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – the features that is used from the pulse map. +E.g. [dom_x, dom_y, dom_z, charge]</p></li> +<li><p><strong>pulsemap_extractor</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>], <a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>]) – The I3FeatureExtractor used to extract the +pulsemap from the I3Frames</p></li> +<li><p><strong>gcd_file</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – Path to the associated gcd-file.</p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.deployment.i3modules.graphnet_module.I3InferenceModule"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.deployment.i3modules.graphnet_module.</span></span><span class="sig-name descname"><span class="pre">I3InferenceModule</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pulsemap</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemap_extractor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_config</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">state_dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gcd_file</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/deployment/i3modules/graphnet_module.html#I3InferenceModule"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module" title="graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a></p> +<p>General class for inference on i3 frames.</p> +<p>General class for inference on I3Frames (physics).</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>pulsemap</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – the pulsmap that the model is expecting as input.</p></li> +<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – the features of the pulsemap that the model is expecting.</p></li> +<li><p><strong>pulsemap_extractor</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>], <a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>]) – The extractor used to extract the pulsemap.</p></li> +<li><p><strong>model_config</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<a class="reference internal" href="graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig" title="graphnet.utilities.config.model_config.ModelConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">ModelConfig</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – The ModelConfig (or path to it) that summarizes the +model used for inference.</p></li> +<li><p><strong>state_dict</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – Path to state_dict containing the learned weights.</p></li> +<li><p><strong>model_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – The name used for the model. Will help define the +named entry in the I3Frame. E.g. “dynedge”.</p></li> +<li><p><strong>gcd_file</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – path to associated gcd file.</p></li> +<li><p><strong>prediction_columns</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – <p>column names for the predictions of the model. +Will help define the named entry in the I3Frame.</p> +<blockquote> +<div><p>E.g. [‘energy_reco’]. Optional.</p> +</div></blockquote> +</p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.deployment.i3modules.graphnet_module.</span></span><span class="sig-name descname"><span class="pre">I3PulseCleanerModule</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pulsemap</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemap_extractor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_config</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">state_dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model_name</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gcd_file</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">threshold</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">discard_empty_events</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/deployment/i3modules/graphnet_module.html#I3PulseCleanerModule"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule" title="graphnet.deployment.i3modules.graphnet_module.I3InferenceModule"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a></p> +<p>A specialized module for pulse cleaning.</p> +<p>It is assumed that the model provided has been trained for this.</p> +<p>General class for inference on I3Frames (physics).</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>pulsemap</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – the pulsmap that the model is expecting as input +(the one that is being cleaned).</p></li> +<li><p><strong>features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – the features of the pulsemap that the model is expecting.</p></li> +<li><p><strong>pulsemap_extractor</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>], <a class="reference internal" href="graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3FeatureExtractor" title="graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"><code class="xref py py-class docutils literal notranslate"><span class="pre">I3FeatureExtractor</span></code></a>]) – The extractor used to extract the pulsemap.</p></li> +<li><p><strong>model_config</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – The ModelConfig (or path to it) that summarizes the +model used for inference.</p></li> +<li><p><strong>state_dict</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – Path to state_dict containing the learned weights.</p></li> +<li><p><strong>model_name</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – The name used for the model. Will help define the named +entry in the I3Frame. E.g. “dynedge”.</p></li> +<li><p><strong>gcd_file</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – path to associated gcd file.</p></li> +<li><p><strong>threshold</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>, default: <code class="docutils literal notranslate"><span class="pre">0.7</span></code>) – the threshold for being considered a positive case. +E.g., predictions >= threshold will be considered +to be signal, all else noise.</p></li> +<li><p><strong>discard_empty_events</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – When true, this flag will eliminate events +whose cleaned pulse series are empty. Can be used +to speed up processing especially for noise +simulation, since it will not do any writing or +further calculations.</p></li> +<li><p><strong>prediction_columns</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – column names for the predictions of the model. +Will help define the named entry in the I3Frame. +E.g. [‘energy_reco’]. Optional.</p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -456,7 +585,7 @@ <h1 id="api-graphnet-deployment-i3modules-graphnet-module--page-root">graphnet_m </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.deployment.i3modules.html b/api/graphnet.deployment.i3modules.html index cb57f2df6..975ce8207 100644 --- a/api/graphnet.deployment.i3modules.html +++ b/api/graphnet.deployment.i3modules.html @@ -410,7 +410,12 @@ <h1 id="api-graphnet-deployment-i3modules--page-root">i3modules<a class="headerl <div class="toctree-wrapper compound"> <ul> <li class="toctree-l1"><a class="reference internal" href="graphnet.deployment.i3modules.deployer.html">deployer</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html">graphnet_module</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html">graphnet_module</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module"><code class="docutils literal notranslate"><span class="pre">GraphNeTI3Module</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule"><code class="docutils literal notranslate"><span class="pre">I3InferenceModule</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule"><code class="docutils literal notranslate"><span class="pre">I3PulseCleanerModule</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -462,7 +467,7 @@ <h1 id="api-graphnet-deployment-i3modules--page-root">i3modules<a class="headerl </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.html b/api/graphnet.html index b04aa430c..5cee1a97c 100644 --- a/api/graphnet.html +++ b/api/graphnet.html @@ -522,7 +522,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.coarsening.html b/api/graphnet.models.coarsening.html index 1dbd5b679..5341555ea 100644 --- a/api/graphnet.models.coarsening.html +++ b/api/graphnet.models.coarsening.html @@ -125,6 +125,7 @@ <script src="../_static/documentation_options.js?v=5929fcd5"></script> <script src="../_static/doctools.js?v=888ff710"></script> <script src="../_static/sphinx_highlight.js?v=dc90522c"></script> + <script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> <link rel="icon" href="../_static/favicon.svg"/> <link rel="author" title="About these documents" href="../about.html" /> <link rel="index" title="Index" href="../genindex.html" /> @@ -365,11 +366,90 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-coarsening--page-root" class="md-nav__link">coarsening</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.unbatch_edge_index" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">unbatch_edge_index()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Coarsening</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening.reduce_options" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reduce_options</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.AttributeCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AttributeCoarsening</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.DOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.CustomDOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CustomDOMCoarsening</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.DOMAndTimeWindowCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMAndTimeWindowCoarsening</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.coarsening.unbatch_edge_index" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">unbatch_edge_index()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.coarsening.Coarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Coarsening</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.coarsening.Coarsening.reduce_options" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reduce_options</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.coarsening.Coarsening.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.coarsening.AttributeCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AttributeCoarsening</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.coarsening.DOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.coarsening.CustomDOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CustomDOMCoarsening</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.coarsening.DOMAndTimeWindowCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMAndTimeWindowCoarsening</span></code></a> + </li></ul> + </li> <li class="md-nav__item"> @@ -436,7 +516,30 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-coarsening--page-root" class="md-nav__link">coarsening</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.unbatch_edge_index" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">unbatch_edge_index()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Coarsening</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening.reduce_options" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reduce_options</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.Coarsening.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.AttributeCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AttributeCoarsening</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.DOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.CustomDOMCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CustomDOMCoarsening</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.coarsening.DOMAndTimeWindowCoarsening" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DOMAndTimeWindowCoarsening</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -446,8 +549,125 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="coarsening"> -<h1 id="api-graphnet-models-coarsening--page-root">coarsening<a class="headerlink" href="#api-graphnet-models-coarsening--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.coarsening"> +<span id="coarsening"></span><h1 id="api-graphnet-models-coarsening--page-root">coarsening<a class="headerlink" href="#api-graphnet-models-coarsening--page-root" title="Link to this heading">¶</a></h1> +<p>Class(es) for coarsening operations (i.e., clustering, or local pooling).</p> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.coarsening.unbatch_edge_index"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">unbatch_edge_index</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#unbatch_edge_index"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.unbatch_edge_index" title="Link to this definition">¶</a></dt> +<dd><p>Splits the <code class="xref py py-obj docutils literal notranslate"><span class="pre">edge_index</span></code> according to a <code class="xref py py-obj docutils literal notranslate"><span class="pre">batch</span></code> vector.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>edge_index</strong> (<em>Tensor</em>) – The edge_index tensor. Must be ordered.</p></li> +<li><p><strong>batch</strong> (<em>LongTensor</em>) – The batch vector +<span class="math notranslate nohighlight">\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\)</span>, which assigns each +node to a specific example. Must be ordered.</p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List[Tensor]</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.coarsening.Coarsening"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">Coarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">reduce</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transfer_attributes</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#Coarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.Coarsening" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>Base class for coarsening operations.</p> +<p>Construct <cite>Coarsening</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>reduce</strong> (<em>str</em>) – </p></li> +<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.coarsening.Coarsening.reduce_options"> +<span class="sig-name descname"><span class="pre">reduce_options</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{'avg':</span> <span class="pre">(<function</span> <span class="pre">avg_pool>,</span> <span class="pre"><function</span> <span class="pre">avg_pool_x>),</span> <span class="pre">'max':</span> <span class="pre">(<function</span> <span class="pre">max_pool>,</span> <span class="pre"><function</span> <span class="pre">max_pool_x>),</span> <span class="pre">'min':</span> <span class="pre">(<function</span> <span class="pre">min_pool>,</span> <span class="pre"><function</span> <span class="pre">min_pool_x>),</span> <span class="pre">'sum':</span> <span class="pre">(<function</span> <span class="pre">sum_pool>,</span> <span class="pre"><function</span> <span class="pre">sum_pool_x>)}</span></em><a class="headerlink" href="#graphnet.models.coarsening.Coarsening.reduce_options" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.coarsening.Coarsening.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#Coarsening.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.Coarsening.forward" title="Link to this definition">¶</a></dt> +<dd><p>Perform coarsening operation.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em><em> | </em><em>Batch</em>) – </p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.coarsening.AttributeCoarsening"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">AttributeCoarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reduce</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transfer_attributes</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#AttributeCoarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.AttributeCoarsening" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.coarsening.Coarsening" title="graphnet.models.coarsening.Coarsening"><code class="xref py py-class docutils literal notranslate"><span class="pre">Coarsening</span></code></a></p> +<p>Coarsen pulses based on specified attributes.</p> +<p>Construct <cite>SimpleCoarsening</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>attributes</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>reduce</strong> (<em>str</em>) – </p></li> +<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.coarsening.DOMCoarsening"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">DOMCoarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">reduce</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transfer_attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#DOMCoarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.DOMCoarsening" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.coarsening.Coarsening" title="graphnet.models.coarsening.Coarsening"><code class="xref py py-class docutils literal notranslate"><span class="pre">Coarsening</span></code></a></p> +<p>Coarsen pulses to DOM-level.</p> +<p>Cluster pulses on the same DOM.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>reduce</strong> (<em>str</em>) – </p></li> +<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li> +<li><p><strong>keys</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.coarsening.CustomDOMCoarsening"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">CustomDOMCoarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">reduce</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transfer_attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#CustomDOMCoarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.CustomDOMCoarsening" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.coarsening.DOMCoarsening" title="graphnet.models.coarsening.DOMCoarsening"><code class="xref py py-class docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a></p> +<p>Coarsen pulses to DOM-level with additional attributes.</p> +<p>Cluster pulses on the same DOM.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>reduce</strong> (<em>str</em>) – </p></li> +<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li> +<li><p><strong>keys</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.coarsening.DOMAndTimeWindowCoarsening"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.coarsening.</span></span><span class="sig-name descname"><span class="pre">DOMAndTimeWindowCoarsening</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="pre">time_window,</span> <span class="pre">reduce,</span> <span class="pre">transfer_attributes,</span> <span class="pre">keys=['dom_x',</span> <span class="pre">'dom_y',</span> <span class="pre">'dom_z',</span> <span class="pre">'rde',</span> <span class="pre">'pmt_area'],</span> <span class="pre">time_key</span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/coarsening.html#DOMAndTimeWindowCoarsening"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.coarsening.DOMAndTimeWindowCoarsening" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.coarsening.Coarsening" title="graphnet.models.coarsening.Coarsening"><code class="xref py py-class docutils literal notranslate"><span class="pre">Coarsening</span></code></a></p> +<p>Coarsen pulses to DOM-level, with additional time-window clustering.</p> +<p>Cluster pulses on the same DOM within <cite>time_window</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>time_window</strong> (<em>float</em>) – </p></li> +<li><p><strong>reduce</strong> (<em>str</em>) – </p></li> +<li><p><strong>transfer_attributes</strong> (<em>bool</em>) – </p></li> +<li><p><strong>keys</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>time_key</strong> (<em>str</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -497,7 +717,7 @@ <h1 id="api-graphnet-models-coarsening--page-root">coarsening<a class="headerlin </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.components.html b/api/graphnet.models.components.html index 831dd56fa..a46e005c6 100644 --- a/api/graphnet.models.components.html +++ b/api/graphnet.models.components.html @@ -460,13 +460,31 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="components"> -<h1 id="api-graphnet-models-components--page-root">components<a class="headerlink" href="#api-graphnet-models-components--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.components"> +<span id="components"></span><h1 id="api-graphnet-models-components--page-root">components<a class="headerlink" href="#api-graphnet-models-components--page-root" title="Link to this heading">¶</a></h1> +<p>Components for constructing models.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.components.layers.html">layers</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.components.pool.html">pool</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.components.layers.html">layers</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv"><code class="docutils literal notranslate"><span class="pre">DynEdgeConv</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito"><code class="docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans"><code class="docutils literal notranslate"><span class="pre">DynTrans</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.components.pool.html">pool</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool"><code class="docutils literal notranslate"><span class="pre">min_pool()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool_x"><code class="docutils literal notranslate"><span class="pre">min_pool_x()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_and_distribute"><code class="docutils literal notranslate"><span class="pre">sum_pool_and_distribute()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.group_by"><code class="docutils literal notranslate"><span class="pre">group_by()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_dom"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_dom()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_pmt"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_pmt()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_x"><code class="docutils literal notranslate"><span class="pre">sum_pool_x()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool_x"><code class="docutils literal notranslate"><span class="pre">std_pool_x()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool"><code class="docutils literal notranslate"><span class="pre">sum_pool()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool"><code class="docutils literal notranslate"><span class="pre">std_pool()</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -518,7 +536,7 @@ <h1 id="api-graphnet-models-components--page-root">components<a class="headerlin </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.components.layers.html b/api/graphnet.models.components.layers.html index c9568c1c1..06e348db9 100644 --- a/api/graphnet.models.components.layers.html +++ b/api/graphnet.models.components.layers.html @@ -336,11 +336,94 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-components-layers--page-root" class="md-nav__link">layers</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynEdgeConv" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeConv</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynEdgeConv.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.reset_parameters" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reset_parameters()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.message" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">message()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynTrans" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynTrans</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynTrans.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.layers.DynEdgeConv" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeConv</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.layers.DynEdgeConv.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.layers.EdgeConvTito" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.layers.EdgeConvTito.reset_parameters" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reset_parameters()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.layers.EdgeConvTito.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.layers.EdgeConvTito.message" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">message()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.layers.DynTrans" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynTrans</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.layers.DynTrans.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -451,7 +534,34 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-components-layers--page-root" class="md-nav__link">layers</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynEdgeConv" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeConv</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynEdgeConv.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.reset_parameters" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">reset_parameters()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.EdgeConvTito.message" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">message()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynTrans" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynTrans</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.layers.DynTrans.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -461,8 +571,145 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="layers"> -<h1 id="api-graphnet-models-components-layers--page-root">layers<a class="headerlink" href="#api-graphnet-models-components-layers--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.components.layers"> +<span id="layers"></span><h1 id="api-graphnet-models-components-layers--page-root">layers<a class="headerlink" href="#api-graphnet-models-components-layers--page-root" title="Link to this heading">¶</a></h1> +<p>Class(es) implementing layers to be used in <cite>graphnet</cite> models.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.components.layers.DynEdgeConv"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.components.layers.</span></span><span class="sig-name descname"><span class="pre">DynEdgeConv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggr</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_neighbors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features_subset</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#DynEdgeConv"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.DynEdgeConv" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeConv</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></p> +<p>Dynamical edge convolution layer.</p> +<p>Construct <cite>DynEdgeConv</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nn</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>) – The MLP/torch.Module to be used within the <cite>EdgeConv</cite>.</p></li> +<li><p><strong>aggr</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'max'</span></code>) – Aggregation method to be used with <cite>EdgeConv</cite>.</p></li> +<li><p><strong>nb_neighbors</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">8</span></code>) – Number of neighbours to be clustered after the +<cite>EdgeConv</cite> operation.</p></li> +<li><p><strong>features_subset</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">slice</span></code>, <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of features in <cite>Data.x</cite> that should be used +when dynamically performing the new graph clustering after the +<cite>EdgeConv</cite> operation. Defaults to all features.</p></li> +<li><p><strong>**kwargs</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>) – Additional features to be passed to <cite>EdgeConv</cite>.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.components.layers.DynEdgeConv.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#DynEdgeConv.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.DynEdgeConv.forward" title="Link to this definition">¶</a></dt> +<dd><p>Forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>x</strong> (<em>Tensor</em>) – </p></li> +<li><p><strong>edge_index</strong> (<em>Tensor</em><em> | </em><em>SparseTensor</em>) – </p></li> +<li><p><strong>batch</strong> (<em>Tensor</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.components.layers.EdgeConvTito"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.components.layers.</span></span><span class="sig-name descname"><span class="pre">EdgeConvTito</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggr</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#EdgeConvTito"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.EdgeConvTito" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">MessagePassing</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></p> +<p>Implementation of EdgeConvTito layer used in TITO solution for.</p> +<p>‘IceCube - Neutrinos in Deep’ kaggle competition.</p> +<p>Construct <cite>EdgeConvTito</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nn</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>) – The MLP/torch.Module to be used within the <cite>EdgeConvTito</cite>.</p></li> +<li><p><strong>aggr</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'max'</span></code>) – Aggregation method to be used with <cite>EdgeConvTito</cite>.</p></li> +<li><p><strong>**kwargs</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>) – Additional features to be passed to <cite>EdgeConvTito</cite>.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.components.layers.EdgeConvTito.reset_parameters"> +<span class="sig-name descname"><span class="pre">reset_parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#EdgeConvTito.reset_parameters"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.EdgeConvTito.reset_parameters" title="Link to this definition">¶</a></dt> +<dd><p>Reset all learnable parameters of the module.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.components.layers.EdgeConvTito.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#EdgeConvTito.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.EdgeConvTito.forward" title="Link to this definition">¶</a></dt> +<dd><p>Forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>x</strong> (<em>Tensor</em><em> | </em><em>Tuple</em><em>[</em><em>Tensor</em><em>, </em><em>Tensor</em><em>]</em>) – </p></li> +<li><p><strong>edge_index</strong> (<em>Tensor</em><em> | </em><em>SparseTensor</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.components.layers.EdgeConvTito.message"> +<span class="sig-name descname"><span class="pre">message</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x_i</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_j</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#EdgeConvTito.message"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.EdgeConvTito.message" title="Link to this definition">¶</a></dt> +<dd><p>Edgeconvtito message passing.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>x_i</strong> (<em>Tensor</em>) – </p></li> +<li><p><strong>x_j</strong> (<em>Tensor</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.components.layers.DynTrans"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.components.layers.</span></span><span class="sig-name descname"><span class="pre">DynTrans</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">layer_sizes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aggr</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features_subset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_head</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#DynTrans"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.DynTrans" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.components.layers.EdgeConvTito" title="graphnet.models.components.layers.EdgeConvTito"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeConvTito</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code></p> +<p>Implementation of dynTrans1 layer used in TITO solution for.</p> +<p>‘IceCube - Neutrinos in Deep’ kaggle competition.</p> +<p>Construct <cite>DynTrans</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nn</strong> – The MLP/torch.Module to be used within the <cite>DynTrans</cite>.</p></li> +<li><p><strong>layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of layer sizes to be used in <cite>DynTrans</cite>.</p></li> +<li><p><strong>aggr</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'max'</span></code>) – Aggregation method to be used with <cite>DynTrans</cite>.</p></li> +<li><p><strong>features_subset</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Sequence</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">slice</span></code>, <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Subset of features in <cite>Data.x</cite> that should be used +when dynamically performing the new graph clustering after the +<cite>EdgeConv</cite> operation. Defaults to all features.</p></li> +<li><p><strong>n_head</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">8</span></code>) – Number of heads to be used in the multiheadattention models.</p></li> +<li><p><strong>**kwargs</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>) – Additional features to be passed to <cite>DynTrans</cite>.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.components.layers.DynTrans.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/layers.html#DynTrans.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.layers.DynTrans.forward" title="Link to this definition">¶</a></dt> +<dd><p>Forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>x</strong> (<em>Tensor</em>) – </p></li> +<li><p><strong>edge_index</strong> (<em>Tensor</em><em> | </em><em>SparseTensor</em>) – </p></li> +<li><p><strong>batch</strong> (<em>Tensor</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -512,7 +759,7 @@ <h1 id="api-graphnet-models-components-layers--page-root">layers<a class="header </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.components.pool.html b/api/graphnet.models.components.pool.html index da3409715..16bcb18ac 100644 --- a/api/graphnet.models.components.pool.html +++ b/api/graphnet.models.components.pool.html @@ -125,6 +125,7 @@ <script src="../_static/documentation_options.js?v=5929fcd5"></script> <script src="../_static/doctools.js?v=888ff710"></script> <script src="../_static/sphinx_highlight.js?v=dc90522c"></script> + <script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> <link rel="icon" href="../_static/favicon.svg"/> <link rel="author" title="About these documents" href="../about.html" /> <link rel="index" title="Index" href="../genindex.html" /> @@ -343,11 +344,106 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-components-pool--page-root" class="md-nav__link">pool</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.min_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.min_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool_x()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool_and_distribute" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_and_distribute()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_by" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_by()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_pulses_to_dom" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_dom()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_pulses_to_pmt" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_pmt()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_x()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.std_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool_x()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.std_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.min_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.min_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool_x()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.sum_pool_and_distribute" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_and_distribute()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.group_by" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_by()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.group_pulses_to_dom" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_dom()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.group_pulses_to_pmt" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_pmt()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.sum_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_x()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.std_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool_x()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.sum_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.components.pool.std_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool()</span></code></a> + + + </li></ul> + </li></ul> </li> @@ -451,7 +547,32 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-components-pool--page-root" class="md-nav__link">pool</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.min_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.min_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">min_pool_x()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool_and_distribute" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_and_distribute()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_by" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_by()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_pulses_to_dom" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_dom()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.group_pulses_to_pmt" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">group_pulses_to_pmt()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool_x()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.std_pool_x" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool_x()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.sum_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">sum_pool()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.components.pool.std_pool" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">std_pool()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -461,8 +582,220 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="pool"> -<h1 id="api-graphnet-models-components-pool--page-root">pool<a class="headerlink" href="#api-graphnet-models-components-pool--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.components.pool"> +<span id="pool"></span><h1 id="api-graphnet-models-components-pool--page-root">pool<a class="headerlink" href="#api-graphnet-models-components-pool--page-root" title="Link to this heading">¶</a></h1> +<p>Functions for performing pooling/clustering/coarsening.</p> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.min_pool"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">min_pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#min_pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.min_pool" title="Link to this definition">¶</a></dt> +<dd><p>Perform min-pooling of <cite>Data</cite>.</p> +<p>Like <cite>max_pool, just negating `data.x</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>cluster</strong> (<em>LongTensor</em>) – </p></li> +<li><p><strong>data</strong> (<em>Data</em>) – </p></li> +<li><p><strong>transform</strong> (<em>Any</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.min_pool_x"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">min_pool_x</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">size</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#min_pool_x"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.min_pool_x" title="Link to this definition">¶</a></dt> +<dd><p>Perform min-pooling of <cite>Tensor</cite>.</p> +<p>Like <cite>max_pool_x, just negating `x</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>cluster</strong> (<em>LongTensor</em>) – </p></li> +<li><p><strong>x</strong> (<em>Tensor</em>) – </p></li> +<li><p><strong>batch</strong> (<em>LongTensor</em>) – </p></li> +<li><p><strong>size</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.sum_pool_and_distribute"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">sum_pool_and_distribute</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">cluster_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#sum_pool_and_distribute"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.sum_pool_and_distribute" title="Link to this definition">¶</a></dt> +<dd><p>Sum-pool values and distribute result to the individual nodes.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>tensor</strong> (<em>Tensor</em>) – </p></li> +<li><p><strong>cluster_index</strong> (<em>LongTensor</em>) – </p></li> +<li><p><strong>batch</strong> (<em>LongTensor</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.group_by"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">group_by</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keys</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#group_by"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.group_by" title="Link to this definition">¶</a></dt> +<dd><p>Group nodes in <cite>data</cite> that have identical values of <cite>keys</cite>.</p> +<p>This grouping is done with in each event in case of batching. This allows +for, e.g., assigning the same index to all pulses on the same PMT or DOM in +the same event. This can be used for coarsening graphs, e.g., from pulse- +level to DOM-level by aggregating feature across each group returned by this +method.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>data</strong> (<em>Data</em><em> | </em><em>Batch</em>) – </p></li> +<li><p><strong>keys</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +</ul> +</dd> +</dl> +<p class="rubric">Example</p> +<dl class="simple"> +<dt>Given:</dt><dd><p>data.f1 = [1,1,2,2,2] +data.f2 = [6,7,7,7,8]</p> +</dd> +<dt>Calls:</dt><dd><p>groupby(data, [‘f1’]) -> [0, 0, 1, 1, 1] +groupby(data, [‘f2’]) -> [0, 1, 1, 1, 2] +groupby(data, [‘f1’, ‘f2’]) -> [0, 1, 2, 2, 3]</p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.group_pulses_to_dom"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">group_pulses_to_dom</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#group_pulses_to_dom"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.group_pulses_to_dom" title="Link to this definition">¶</a></dt> +<dd><p>Group pulses on the same DOM, using DOM and string number.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.group_pulses_to_pmt"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">group_pulses_to_pmt</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#group_pulses_to_pmt"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.group_pulses_to_pmt" title="Link to this definition">¶</a></dt> +<dd><p>Group pulses on the same PMT, using PMT, DOM, and string number.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.sum_pool_x"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">sum_pool_x</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">size</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#sum_pool_x"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.sum_pool_x" title="Link to this definition">¶</a></dt> +<dd><p>Sum-pool node features according to the clustering defined in <cite>cluster</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>cluster</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Cluster vector <span class="math notranslate nohighlight">\(\mathbf{c} \in \{ 0, +\ldots, N - 1 \}^N\)</span>, which assigns each node to a specific cluster.</p></li> +<li><p><strong>x</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – Node feature matrix +<span class="math notranslate nohighlight">\(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\)</span>.</p></li> +<li><p><strong>batch</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Batch vector <span class="math notranslate nohighlight">\(\mathbf{b} \in {\{ 0, \ldots, +B-1\}}^N\)</span>, which assigns each node to a specific example.</p></li> +<li><p><strong>size</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The maximum number of clusters in a single +example. This property is useful to obtain a batch-wise dense +representation, <em>e.g.</em> for applying FC layers, but should only be +used if the size of the maximum number of clusters per example is +known in advance.</p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.std_pool_x"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">std_pool_x</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">size</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#std_pool_x"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.std_pool_x" title="Link to this definition">¶</a></dt> +<dd><p>Std-pool node features according to the clustering defined in <cite>cluster</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>cluster</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Cluster vector <span class="math notranslate nohighlight">\(\mathbf{c} \in \{ 0, +\ldots, N - 1 \}^N\)</span>, which assigns each node to a specific cluster.</p></li> +<li><p><strong>x</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – Node feature matrix +<span class="math notranslate nohighlight">\(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\)</span>.</p></li> +<li><p><strong>batch</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Batch vector <span class="math notranslate nohighlight">\(\mathbf{b} \in {\{ 0, \ldots, +B-1\}}^N\)</span>, which assigns each node to a specific example.</p></li> +<li><p><strong>size</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The maximum number of clusters in a single +example. This property is useful to obtain a batch-wise dense +representation, <em>e.g.</em> for applying FC layers, but should only be +used if the size of the maximum number of clusters per example is +known in advance.</p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.sum_pool"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">sum_pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#sum_pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.sum_pool" title="Link to this definition">¶</a></dt> +<dd><p>Pool and coarsen graph according to the clustering defined in <cite>cluster</cite>.</p> +<p>All nodes within the same cluster will be represented as one node. +Final node features are defined by the <em>sum</em> of features of all nodes +within the same cluster, node positions are averaged and edge indices are +defined to be the union of the edge indices of all nodes within the same +cluster.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>cluster</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Cluster vector <span class="math notranslate nohighlight">\(\mathbf{c} \in \{ 0, +\ldots, N - 1 \}^N\)</span>, which assigns each node to a specific cluster.</p></li> +<li><p><strong>data</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>) – Graph data object.</p></li> +<li><p><strong>transform</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – A function/transform that takes in the +coarsened and pooled <code class="xref py py-obj docutils literal notranslate"><span class="pre">torch_geometric.data.Data</span></code> object and +returns a transformed version.</p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.components.pool.std_pool"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.components.pool.</span></span><span class="sig-name descname"><span class="pre">std_pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cluster</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/components/pool.html#std_pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.components.pool.std_pool" title="Link to this definition">¶</a></dt> +<dd><p>Pool and coarsen graph according to the clustering defined in <cite>cluster</cite>.</p> +<p>All nodes within the same cluster will be represented as one node. +Final node features are defined by the <em>std</em> of features of all nodes +within the same cluster, node positions are averaged and edge indices are +defined to be the union of the edge indices of all nodes within the same +cluster.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>cluster</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">LongTensor</span></code>) – Cluster vector <span class="math notranslate nohighlight">\(\mathbf{c} \in \{ 0, +\ldots, N - 1 \}^N\)</span>, which assigns each node to a specific cluster.</p></li> +<li><p><strong>data</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>) – Graph data object.</p></li> +<li><p><strong>transform</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – A function/transform that takes in the +coarsened and pooled <code class="xref py py-obj docutils literal notranslate"><span class="pre">torch_geometric.data.Data</span></code> object and +returns a transformed version.</p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p> +</dd> +</dl> +</dd></dl> </section> @@ -512,7 +845,7 @@ <h1 id="api-graphnet-models-components-pool--page-root">pool<a class="headerlink </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.detector.detector.html b/api/graphnet.models.detector.detector.html index b8a0e72ee..773954d8d 100644 --- a/api/graphnet.models.detector.detector.html +++ b/api/graphnet.models.detector.detector.html @@ -343,11 +343,45 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-detector-detector--page-root" class="md-nav__link">detector</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Detector</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.detector.Detector" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Detector</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.detector.Detector.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.detector.Detector.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -458,7 +492,20 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-detector-detector--page-root" class="md-nav__link">detector</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Detector</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.detector.detector.Detector.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -468,8 +515,44 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="detector"> -<h1 id="api-graphnet-models-detector-detector--page-root">detector<a class="headerlink" href="#api-graphnet-models-detector-detector--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.detector.detector"> +<span id="detector"></span><h1 id="api-graphnet-models-detector-detector--page-root">detector<a class="headerlink" href="#api-graphnet-models-detector-detector--page-root" title="Link to this heading">¶</a></h1> +<p>Base detector-specific <cite>Model</cite> class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.detector.detector.Detector"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.detector.</span></span><span class="sig-name descname"><span class="pre">Detector</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/detector.html#Detector"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.detector.Detector" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>Base class for all detector-specific read-ins in graphnet.</p> +<p>Construct <cite>Detector</cite>.</p> +<dl class="field-list simple"> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.detector.detector.Detector.feature_map"> +<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/detector.html#Detector.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.detector.Detector.feature_map" title="Link to this definition">¶</a></dt> +<dd><p>List of features used/assumed by inheriting <cite>Detector</cite> objects.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.detector.detector.Detector.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">node_features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/detector.html#Detector.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.detector.Detector.forward" title="Link to this definition">¶</a></dt> +<dd><p>Pre-process graph <cite>Data</cite> features and build graph adjacency.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>node_features</strong> (<em>tensor</em>) – </p></li> +<li><p><strong>node_feature_names</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -519,7 +602,7 @@ <h1 id="api-graphnet-models-detector-detector--page-root">detector<a class="head </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.detector.html b/api/graphnet.models.detector.html index 0c7904b36..d42b5269f 100644 --- a/api/graphnet.models.detector.html +++ b/api/graphnet.models.detector.html @@ -467,14 +467,27 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="detector"> -<h1 id="api-graphnet-models-detector--page-root">detector<a class="headerlink" href="#api-graphnet-models-detector--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.detector"> +<span id="detector"></span><h1 id="api-graphnet-models-detector--page-root">detector<a class="headerlink" href="#api-graphnet-models-detector--page-root" title="Link to this heading">¶</a></h1> +<p>Detector-specific modules, for data ingestion and standardisation.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.detector.html">detector</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.icecube.html">icecube</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.prometheus.html">prometheus</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.detector.html">detector</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector"><code class="docutils literal notranslate"><span class="pre">Detector</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.icecube.html">icecube</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86"><code class="docutils literal notranslate"><span class="pre">IceCube86</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle"><code class="docutils literal notranslate"><span class="pre">IceCubeKaggle</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore"><code class="docutils literal notranslate"><span class="pre">IceCubeDeepCore</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade"><code class="docutils literal notranslate"><span class="pre">IceCubeUpgrade</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.detector.prometheus.html">prometheus</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus"><code class="docutils literal notranslate"><span class="pre">Prometheus</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -526,7 +539,7 @@ <h1 id="api-graphnet-models-detector--page-root">detector<a class="headerlink" h </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.detector.icecube.html b/api/graphnet.models.detector.icecube.html index 003c24f60..e37d2e2da 100644 --- a/api/graphnet.models.detector.icecube.html +++ b/api/graphnet.models.detector.icecube.html @@ -350,11 +350,96 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-detector-icecube--page-root" class="md-nav__link">icecube</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCube86" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCube86</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCube86.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeKaggle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeKaggle</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeKaggle.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeDeepCore" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeDeepCore</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeUpgrade" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeUpgrade</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.icecube.IceCube86" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCube86</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.icecube.IceCube86.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.icecube.IceCubeKaggle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeKaggle</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.icecube.IceCubeKaggle.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.icecube.IceCubeDeepCore" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeDeepCore</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.icecube.IceCubeUpgrade" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeUpgrade</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -458,7 +543,36 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-detector-icecube--page-root" class="md-nav__link">icecube</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCube86" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCube86</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCube86.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeKaggle" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeKaggle</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeKaggle.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeDeepCore" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeDeepCore</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeUpgrade" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IceCubeUpgrade</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -468,8 +582,85 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="icecube"> -<h1 id="api-graphnet-models-detector-icecube--page-root">icecube<a class="headerlink" href="#api-graphnet-models-detector-icecube--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.detector.icecube"> +<span id="icecube"></span><h1 id="api-graphnet-models-detector-icecube--page-root">icecube<a class="headerlink" href="#api-graphnet-models-detector-icecube--page-root" title="Link to this heading">¶</a></h1> +<p>IceCube-specific <cite>Detector</cite> class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCube86"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.icecube.</span></span><span class="sig-name descname"><span class="pre">IceCube86</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCube86"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCube86" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p> +<p><cite>Detector</cite> class for IceCube-86.</p> +<p>Construct <cite>Detector</cite>.</p> +<dl class="field-list simple"> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCube86.feature_map"> +<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCube86.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCube86.feature_map" title="Link to this definition">¶</a></dt> +<dd><p>Map standardization functions to each dimension of input data.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeKaggle"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.icecube.</span></span><span class="sig-name descname"><span class="pre">IceCubeKaggle</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeKaggle"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeKaggle" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p> +<p><cite>Detector</cite> class for Kaggle Competition.</p> +<p>Construct <cite>Detector</cite>.</p> +<dl class="field-list simple"> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeKaggle.feature_map"> +<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeKaggle.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeKaggle.feature_map" title="Link to this definition">¶</a></dt> +<dd><p>Map standardization functions to each dimension of input data.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeDeepCore"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.icecube.</span></span><span class="sig-name descname"><span class="pre">IceCubeDeepCore</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeDeepCore"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeDeepCore" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p> +<p><cite>Detector</cite> class for IceCube-DeepCore.</p> +<p>Construct <cite>Detector</cite>.</p> +<dl class="field-list simple"> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeDeepCore.feature_map"> +<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeDeepCore.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map" title="Link to this definition">¶</a></dt> +<dd><p>Map standardization functions to each dimension of input data.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeUpgrade"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.icecube.</span></span><span class="sig-name descname"><span class="pre">IceCubeUpgrade</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeUpgrade"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeUpgrade" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p> +<p><cite>Detector</cite> class for IceCube-Upgrade.</p> +<p>Construct <cite>Detector</cite>.</p> +<dl class="field-list simple"> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.detector.icecube.IceCubeUpgrade.feature_map"> +<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/icecube.html#IceCubeUpgrade.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map" title="Link to this definition">¶</a></dt> +<dd><p>Map standardization functions to each dimension of input data.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -519,7 +710,7 @@ <h1 id="api-graphnet-models-detector-icecube--page-root">icecube<a class="header </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.detector.prometheus.html b/api/graphnet.models.detector.prometheus.html index 9b27ebab9..0c266b36a 100644 --- a/api/graphnet.models.detector.prometheus.html +++ b/api/graphnet.models.detector.prometheus.html @@ -357,11 +357,36 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-detector-prometheus--page-root" class="md-nav__link">prometheus</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.prometheus.Prometheus" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Prometheus</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.prometheus.Prometheus.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.prometheus.Prometheus" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Prometheus</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.detector.prometheus.Prometheus.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + + </li></ul> + </li></ul> </li> @@ -458,7 +483,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-detector-prometheus--page-root" class="md-nav__link">prometheus</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.prometheus.Prometheus" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Prometheus</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.detector.prometheus.Prometheus.feature_map" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">feature_map()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -468,8 +504,28 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="prometheus"> -<h1 id="api-graphnet-models-detector-prometheus--page-root">prometheus<a class="headerlink" href="#api-graphnet-models-detector-prometheus--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.detector.prometheus"> +<span id="prometheus"></span><h1 id="api-graphnet-models-detector-prometheus--page-root">prometheus<a class="headerlink" href="#api-graphnet-models-detector-prometheus--page-root" title="Link to this heading">¶</a></h1> +<p>Prometheus-specific <cite>Detector</cite> class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.detector.prometheus.Prometheus"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.detector.prometheus.</span></span><span class="sig-name descname"><span class="pre">Prometheus</span></span><a class="reference internal" href="../_modules/graphnet/models/detector/prometheus.html#Prometheus"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.prometheus.Prometheus" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a></p> +<p><cite>Detector</cite> class for Prometheus prototype.</p> +<p>Construct <cite>Detector</cite>.</p> +<dl class="field-list simple"> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.detector.prometheus.Prometheus.feature_map"> +<span class="sig-name descname"><span class="pre">feature_map</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/detector/prometheus.html#Prometheus.feature_map"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.detector.prometheus.Prometheus.feature_map" title="Link to this definition">¶</a></dt> +<dd><p>Map standardization functions to each dimension.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>]</p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -519,7 +575,7 @@ <h1 id="api-graphnet-models-detector-prometheus--page-root">prometheus<a class=" </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.gnn.convnet.html b/api/graphnet.models.gnn.convnet.html index 1d2d39cd2..52d6ba1ea 100644 --- a/api/graphnet.models.gnn.convnet.html +++ b/api/graphnet.models.gnn.convnet.html @@ -350,11 +350,36 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-convnet--page-root" class="md-nav__link">convnet</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.convnet.ConvNet" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ConvNet</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.convnet.ConvNet.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.convnet.ConvNet" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ConvNet</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.convnet.ConvNet.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -472,7 +497,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-convnet--page-root" class="md-nav__link">convnet</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.convnet.ConvNet" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ConvNet</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.convnet.ConvNet.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -482,8 +518,42 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="convnet"> -<h1 id="api-graphnet-models-gnn-convnet--page-root">convnet<a class="headerlink" href="#api-graphnet-models-gnn-convnet--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.gnn.convnet"> +<span id="convnet"></span><h1 id="api-graphnet-models-gnn-convnet--page-root">convnet<a class="headerlink" href="#api-graphnet-models-gnn-convnet--page-root" title="Link to this heading">¶</a></h1> +<p>Implementation of the ConvNet GNN model architecture.</p> +<p>Author: Martin Ha Minh</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.gnn.convnet.ConvNet"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.convnet.</span></span><span class="sig-name descname"><span class="pre">ConvNet</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_intermediate</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_ratio</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/convnet.html#ConvNet"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.convnet.ConvNet" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><code class="xref py py-class docutils literal notranslate"><span class="pre">GNN</span></code></a></p> +<p>ConvNet (convolutional network) model.</p> +<p>Construct <cite>ConvNet</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nb_inputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of input features, i.e. dimension of input +layer.</p></li> +<li><p><strong>nb_outputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of prediction labels, i.e. dimension of +output layer.</p></li> +<li><p><strong>nb_intermediate</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">128</span></code>) – Number of nodes in intermediate layer(s).</p></li> +<li><p><strong>dropout_ratio</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>, default: <code class="docutils literal notranslate"><span class="pre">0.3</span></code>) – Fraction of nodes to drop.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.gnn.convnet.ConvNet.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/convnet.html#ConvNet.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.convnet.ConvNet.forward" title="Link to this definition">¶</a></dt> +<dd><p>Apply learnable forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -533,7 +603,7 @@ <h1 id="api-graphnet-models-gnn-convnet--page-root">convnet<a class="headerlink" </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.gnn.dynedge.html b/api/graphnet.models.gnn.dynedge.html index 1ae3f0c44..908284418 100644 --- a/api/graphnet.models.gnn.dynedge.html +++ b/api/graphnet.models.gnn.dynedge.html @@ -357,11 +357,36 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge--page-root" class="md-nav__link">dynedge</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge.DynEdge" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdge</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge.DynEdge.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.dynedge.DynEdge" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdge</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.dynedge.DynEdge.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -472,7 +497,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge--page-root" class="md-nav__link">dynedge</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge.DynEdge" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdge</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge.DynEdge.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -482,8 +518,64 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="dynedge"> -<h1 id="api-graphnet-models-gnn-dynedge--page-root">dynedge<a class="headerlink" href="#api-graphnet-models-gnn-dynedge--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.gnn.dynedge"> +<span id="dynedge"></span><h1 id="api-graphnet-models-gnn-dynedge--page-root">dynedge<a class="headerlink" href="#api-graphnet-models-gnn-dynedge--page-root" title="Link to this heading">¶</a></h1> +<p>Implementation of the DynEdge GNN model architecture.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge.DynEdge"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.dynedge.</span></span><span class="sig-name descname"><span class="pre">DynEdge</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_inputs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_neighbours</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features_subset</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dynedge_layer_sizes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">post_processing_layer_sizes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">readout_layer_sizes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">global_pooling_schemes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">add_global_variables_after_pooling</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge.html#DynEdge"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge.DynEdge" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><code class="xref py py-class docutils literal notranslate"><span class="pre">GNN</span></code></a></p> +<p>DynEdge (dynamical edge convolutional) model.</p> +<p>Construct <cite>DynEdge</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nb_inputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of input features on each node.</p></li> +<li><p><strong>nb_neighbours</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">8</span></code>) – Number of neighbours to used in the k-nearest +neighbour clustering which is performed after each (dynamical) +edge convolution.</p></li> +<li><p><strong>features_subset</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], <code class="xref py py-class docutils literal notranslate"><span class="pre">slice</span></code>, <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The subset of latent features on each node that +are used as metric dimensions when performing the k-nearest +neighbours clustering. Defaults to [0,1,2].</p></li> +<li><p><strong>dynedge_layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The layer sizes, or latent feature dimenions, +used in the <cite>DynEdgeConv</cite> layer. Each entry in +<cite>dynedge_layer_sizes</cite> corresponds to a single <cite>DynEdgeConv</cite> +layer; the integers in the corresponding tuple corresponds to +the layer sizes in the multi-layer perceptron (MLP) that is +applied within each <cite>DynEdgeConv</cite> layer. That is, a list of +size-two tuples means that all <cite>DynEdgeConv</cite> layers contain a +two-layer MLP. +Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)].</p></li> +<li><p><strong>post_processing_layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Hidden layer sizes in the MLP +following the skip-concatenation of the outputs of each +<cite>DynEdgeConv</cite> layer. Defaults to [336, 256].</p></li> +<li><p><strong>readout_layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Hidden layer sizes in the MLP following the +post-processing _and_ optional global pooling. As this is the +last layer(s) in the model, the last layer in the read-out +yields the output of the <cite>DynEdge</cite> model. Defaults to [128,].</p></li> +<li><p><strong>global_pooling_schemes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The list global pooling schemes to use. +Options are: “min”, “max”, “mean”, and “sum”.</p></li> +<li><p><strong>add_global_variables_after_pooling</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – Whether to add global variables +after global pooling. The alternative is to added (distribute) +them to the individual nodes before any convolutional +operations.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge.DynEdge.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge.html#DynEdge.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge.DynEdge.forward" title="Link to this definition">¶</a></dt> +<dd><p>Apply learnable forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -533,7 +625,7 @@ <h1 id="api-graphnet-models-gnn-dynedge--page-root">dynedge<a class="headerlink" </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.gnn.dynedge_jinst.html b/api/graphnet.models.gnn.dynedge_jinst.html index f46df1144..407d0cb08 100644 --- a/api/graphnet.models.gnn.dynedge_jinst.html +++ b/api/graphnet.models.gnn.dynedge_jinst.html @@ -364,11 +364,36 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge-jinst--page-root" class="md-nav__link">dynedge_jinst</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeJINST</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeJINST</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -472,7 +497,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge-jinst--page-root" class="md-nav__link">dynedge_jinst</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeJINST</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -482,8 +518,39 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="dynedge-jinst"> -<h1 id="api-graphnet-models-gnn-dynedge-jinst--page-root">dynedge_jinst<a class="headerlink" href="#api-graphnet-models-gnn-dynedge-jinst--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.gnn.dynedge_jinst"> +<span id="dynedge-jinst"></span><h1 id="api-graphnet-models-gnn-dynedge-jinst--page-root">dynedge_jinst<a class="headerlink" href="#api-graphnet-models-gnn-dynedge-jinst--page-root" title="Link to this heading">¶</a></h1> +<p>Implementation of the exact DynEdge architecture used in [2209.03042].</p> +<p>Author: Rasmus Oersoe</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge_jinst.DynEdgeJINST"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.dynedge_jinst.</span></span><span class="sig-name descname"><span class="pre">DynEdgeJINST</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">layer_size_scale</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge_jinst.html#DynEdgeJINST"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><code class="xref py py-class docutils literal notranslate"><span class="pre">GNN</span></code></a></p> +<p>DynEdge (dynamical edge convolutional) model used in [2209.03042].</p> +<p>Construct <cite>DynEdgeJINST</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nb_inputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of input features.</p></li> +<li><p><strong>nb_outputs</strong> – Number of output features.</p></li> +<li><p><strong>layer_size_scale</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">4</span></code>) – Integer that scales the size of hidden layers.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge_jinst.html#DynEdgeJINST.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward" title="Link to this definition">¶</a></dt> +<dd><p>Apply learnable forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -533,7 +600,7 @@ <h1 id="api-graphnet-models-gnn-dynedge-jinst--page-root">dynedge_jinst<a class= </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.gnn.dynedge_kaggle_tito.html b/api/graphnet.models.gnn.dynedge_kaggle_tito.html index 6f259d9bc..a27928192 100644 --- a/api/graphnet.models.gnn.dynedge_kaggle_tito.html +++ b/api/graphnet.models.gnn.dynedge_kaggle_tito.html @@ -371,11 +371,36 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge-kaggle-tito--page-root" class="md-nav__link">dynedge_kaggle_tito</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeTITO</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeTITO</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -472,7 +497,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-dynedge-kaggle-tito--page-root" class="md-nav__link">dynedge_kaggle_tito</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DynEdgeTITO</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -482,8 +518,49 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="dynedge-kaggle-tito"> -<h1 id="api-graphnet-models-gnn-dynedge-kaggle-tito--page-root">dynedge_kaggle_tito<a class="headerlink" href="#api-graphnet-models-gnn-dynedge-kaggle-tito--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.gnn.dynedge_kaggle_tito"> +<span id="dynedge-kaggle-tito"></span><h1 id="api-graphnet-models-gnn-dynedge-kaggle-tito--page-root">dynedge_kaggle_tito<a class="headerlink" href="#api-graphnet-models-gnn-dynedge-kaggle-tito--page-root" title="Link to this heading">¶</a></h1> +<p>Implementation of DynEdge architecture used in.</p> +<blockquote> +<div><p>IceCube - Neutrinos in Deep Ice</p> +</div></blockquote> +<p>Reconstruct the direction of neutrinos from the Universe to the South Pole</p> +<p>Kaggle competition.</p> +<p>Solution by TITO.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.dynedge_kaggle_tito.</span></span><span class="sig-name descname"><span class="pre">DynEdgeTITO</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="pre">nb_inputs,</span> <span class="pre">features_subset(0,</span> <span class="pre">4,</span> <span class="pre">None),</span> <span class="pre">dyntrans_layer_sizes,</span> <span class="pre">global_pooling_schemes=['max']</span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge_kaggle_tito.html#DynEdgeTITO"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><code class="xref py py-class docutils literal notranslate"><span class="pre">GNN</span></code></a></p> +<p>DynEdge (dynamical edge convolutional) model.</p> +<p>Construct <cite>DynEdge</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nb_inputs</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – Number of input features on each node.</p></li> +<li><p><strong>features_subset</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">slice</span></code>, default: <code class="docutils literal notranslate"><span class="pre">slice(0,</span> <span class="pre">4,</span> <span class="pre">None)</span></code>) – The subset of latent features on each node that +are used as metric dimensions when performing the k-nearest +neighbours clustering. Defaults to [0,1,2,3].</p></li> +<li><p><strong>dyntrans_layer_sizes</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>]]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The layer sizes, or latent feature dimenions, +used in the <cite>DynTrans</cite> layer.</p></li> +<li><p><strong>global_pooling_schemes</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">['max']</span></code>) – The list global pooling schemes to use. +Options are: “min”, “max”, “mean”, and “sum”.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/dynedge_kaggle_tito.html#DynEdgeTITO.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward" title="Link to this definition">¶</a></dt> +<dd><p>Apply learnable forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -533,7 +610,7 @@ <h1 id="api-graphnet-models-gnn-dynedge-kaggle-tito--page-root">dynedge_kaggle_t </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.gnn.gnn.html b/api/graphnet.models.gnn.gnn.html index a59853501..78ee2338d 100644 --- a/api/graphnet.models.gnn.gnn.html +++ b/api/graphnet.models.gnn.gnn.html @@ -378,11 +378,54 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-gnn--page-root" class="md-nav__link">gnn</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GNN</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.gnn.GNN" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GNN</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.gnn.GNN.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.gnn.GNN.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.gnn.gnn.GNN.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + + </li></ul> + </li></ul> </li> @@ -472,7 +515,22 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-gnn-gnn--page-root" class="md-nav__link">gnn</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GNN</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.gnn.gnn.GNN.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -482,8 +540,47 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="gnn"> -<h1 id="api-graphnet-models-gnn-gnn--page-root">gnn<a class="headerlink" href="#api-graphnet-models-gnn-gnn--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.gnn.gnn"> +<span id="gnn"></span><h1 id="api-graphnet-models-gnn-gnn--page-root">gnn<a class="headerlink" href="#api-graphnet-models-gnn-gnn--page-root" title="Link to this heading">¶</a></h1> +<p>Base GNN-specific <cite>Model</cite> class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.gnn.gnn.GNN"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.gnn.gnn.</span></span><span class="sig-name descname"><span class="pre">GNN</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_outputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/gnn.html#GNN"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.gnn.GNN" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>Base class for all core GNN models in graphnet.</p> +<p>Construct <cite>GNN</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nb_inputs</strong> (<em>int</em>) – </p></li> +<li><p><strong>nb_outputs</strong> (<em>int</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.gnn.gnn.GNN.nb_inputs"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.gnn.gnn.GNN.nb_inputs" title="Link to this definition">¶</a></dt> +<dd><p>Return number of input features.</p> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.gnn.gnn.GNN.nb_outputs"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_outputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.gnn.gnn.GNN.nb_outputs" title="Link to this definition">¶</a></dt> +<dd><p>Return number of output features.</p> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.gnn.gnn.GNN.forward"> +<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/gnn/gnn.html#GNN.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.gnn.gnn.GNN.forward" title="Link to this definition">¶</a></dt> +<dd><p>Apply learnable forward pass in model.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -533,7 +630,7 @@ <h1 id="api-graphnet-models-gnn-gnn--page-root">gnn<a class="headerlink" href="# </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.gnn.html b/api/graphnet.models.gnn.html index a98fe7164..be8f47a7f 100644 --- a/api/graphnet.models.gnn.html +++ b/api/graphnet.models.gnn.html @@ -481,16 +481,32 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="gnn"> -<h1 id="api-graphnet-models-gnn--page-root">gnn<a class="headerlink" href="#api-graphnet-models-gnn--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.gnn"> +<span id="gnn"></span><h1 id="api-graphnet-models-gnn--page-root">gnn<a class="headerlink" href="#api-graphnet-models-gnn--page-root" title="Link to this heading">¶</a></h1> +<p>GNN-specific modules, for performing the main learnable operations.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.convnet.html">convnet</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge.html">dynedge</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge_jinst.html">dynedge_jinst</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge_kaggle_tito.html">dynedge_kaggle_tito</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.gnn.html">gnn</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.convnet.html">convnet</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet"><code class="docutils literal notranslate"><span class="pre">ConvNet</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge.html">dynedge</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge"><code class="docutils literal notranslate"><span class="pre">DynEdge</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge_jinst.html">dynedge_jinst</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST"><code class="docutils literal notranslate"><span class="pre">DynEdgeJINST</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.dynedge_kaggle_tito.html">dynedge_kaggle_tito</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO"><code class="docutils literal notranslate"><span class="pre">DynEdgeTITO</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.gnn.gnn.html">gnn</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN"><code class="docutils literal notranslate"><span class="pre">GNN</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -542,7 +558,7 @@ <h1 id="api-graphnet-models-gnn--page-root">gnn<a class="headerlink" href="#api- </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.graphs.edges.edges.html b/api/graphnet.models.graphs.edges.edges.html index ea7739e62..98b04f64a 100644 --- a/api/graphnet.models.graphs.edges.edges.html +++ b/api/graphnet.models.graphs.edges.edges.html @@ -363,11 +363,63 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-graphs-edges-edges--page-root" class="md-nav__link">edges</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EdgeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EdgeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.KNNEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNEdges</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.RadialEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RadialEdges</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EuclideanEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanEdges</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.edges.edges.EdgeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.edges.edges.EdgeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.edges.edges.KNNEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNEdges</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.edges.edges.RadialEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RadialEdges</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.edges.edges.EuclideanEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanEdges</span></code></a> + + + </li></ul> + </li></ul> </li> @@ -473,7 +525,24 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-graphs-edges-edges--page-root" class="md-nav__link">edges</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EdgeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EdgeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.KNNEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNEdges</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.RadialEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RadialEdges</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.edges.edges.EuclideanEdges" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanEdges</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -483,8 +552,102 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="edges"> -<h1 id="api-graphnet-models-graphs-edges-edges--page-root">edges<a class="headerlink" href="#api-graphnet-models-graphs-edges-edges--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.graphs.edges.edges"> +<span id="edges"></span><h1 id="api-graphnet-models-graphs-edges-edges--page-root">edges<a class="headerlink" href="#api-graphnet-models-graphs-edges-edges--page-root" title="Link to this heading">¶</a></h1> +<p>Class(es) for building/connecting graphs.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.EdgeDefinition"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.edges.edges.</span></span><span class="sig-name descname"><span class="pre">EdgeDefinition</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">class_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">level</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">log_folder</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#EdgeDefinition"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.EdgeDefinition" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>Base class for graph building.</p> +<p>Construct <cite>Logger</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>name</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>class_name</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>level</strong> (<em>int</em>) – </p></li> +<li><p><strong>log_folder</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.EdgeDefinition.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">graph</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#EdgeDefinition.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.EdgeDefinition.forward" title="Link to this definition">¶</a></dt> +<dd><p>Construct edges based on problem specific implementation of.</p> +<p>´_construct_edges´</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>graph</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>) – a graph without edges</p> +</dd> +<dt class="field-even">Returns<span class="colon">:</span></dt> +<dd class="field-even"><p>a graph with edges</p> +</dd> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p>graph</p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.KNNEdges"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.edges.edges.</span></span><span class="sig-name descname"><span class="pre">KNNEdges</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_nearest_neighbours</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#KNNEdges"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.KNNEdges" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.graphs.edges.edges.EdgeDefinition" title="graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a></p> +<p>Builds edges from the k-nearest neighbours.</p> +<p>K-NN Edge definition.</p> +<p>Will connect nodes together with their ´nb_nearest_neighbours´ +nearest neighbours in the feature space given by ´columns´.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nb_nearest_neighbours</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – number of neighbours.</p></li> +<li><p><strong>columns</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></code>) – Node features to use for distance calculation.</p></li> +<li><p><strong>[</strong><strong>0</strong> (<em>Defaults to</em>) – </p></li> +<li><p><strong>1</strong> – </p></li> +<li><p><strong>2</strong><strong>]</strong><strong>.</strong> – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.RadialEdges"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.edges.edges.</span></span><span class="sig-name descname"><span class="pre">RadialEdges</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">radius</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#RadialEdges"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.RadialEdges" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.graphs.edges.edges.EdgeDefinition" title="graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a></p> +<p>Builds graph from a sphere of chosen radius centred at each node.</p> +<p>Radial edges.</p> +<p>Connects each node to other nodes that are within a sphere of +radius ´r´ centered at the node. The feature space of ´r´ is defined +by ´columns´</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>radius</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>) – radius of sphere</p></li> +<li><p><strong>columns</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></code>) – columns of the node feature matrix used.</p></li> +<li><p><strong>[</strong><strong>0</strong> (<em>Defaults to</em>) – </p></li> +<li><p><strong>1</strong> – </p></li> +<li><p><strong>2</strong><strong>]</strong><strong>.</strong> – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.graphs.edges.edges.EuclideanEdges"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.edges.edges.</span></span><span class="sig-name descname"><span class="pre">EuclideanEdges</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">sigma</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">threshold</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/edges/edges.html#EuclideanEdges"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.edges.edges.EuclideanEdges" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.graphs.edges.edges.EdgeDefinition" title="graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a></p> +<p>Builds edges according to Euclidean distance between nodes.</p> +<p>See <a class="reference external" href="https://arxiv.org/pdf/1809.06166.pdf">https://arxiv.org/pdf/1809.06166.pdf</a>.</p> +<p>Construct <cite>EuclideanEdges</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>sigma</strong> (<em>float</em>) – </p></li> +<li><p><strong>threshold</strong> (<em>float</em>) – </p></li> +<li><p><strong>columns</strong> (<em>List</em><em>[</em><em>int</em><em>]</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -534,7 +697,7 @@ <h1 id="api-graphnet-models-graphs-edges-edges--page-root">edges<a class="header </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.graphs.edges.html b/api/graphnet.models.graphs.edges.html index e6d65d3e9..e4feb5f18 100644 --- a/api/graphnet.models.graphs.edges.html +++ b/api/graphnet.models.graphs.edges.html @@ -482,12 +482,22 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="edges"> -<h1 id="api-graphnet-models-graphs-edges--page-root">edges<a class="headerlink" href="#api-graphnet-models-graphs-edges--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.graphs.edges"> +<span id="edges"></span><h1 id="api-graphnet-models-graphs-edges--page-root">edges<a class="headerlink" href="#api-graphnet-models-graphs-edges--page-root" title="Link to this heading">¶</a></h1> +<p>Modules for constructing graphs.</p> +<p>´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html">edges</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html">edges</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.KNNEdges"><code class="docutils literal notranslate"><span class="pre">KNNEdges</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.RadialEdges"><code class="docutils literal notranslate"><span class="pre">RadialEdges</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EuclideanEdges"><code class="docutils literal notranslate"><span class="pre">EuclideanEdges</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -539,7 +549,7 @@ <h1 id="api-graphnet-models-graphs-edges--page-root">edges<a class="headerlink" </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.graphs.graph_definition.html b/api/graphnet.models.graphs.graph_definition.html index 45116549f..8dcade719 100644 --- a/api/graphnet.models.graphs.graph_definition.html +++ b/api/graphnet.models.graphs.graph_definition.html @@ -371,11 +371,36 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-graphs-graph-definition--page-root" class="md-nav__link">graph_definition</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.graph_definition.GraphDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.graph_definition.GraphDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.graph_definition.GraphDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.graph_definition.GraphDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -465,7 +490,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-graphs-graph-definition--page-root" class="md-nav__link">graph_definition</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.graph_definition.GraphDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.graph_definition.GraphDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -475,8 +511,59 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="graph-definition"> -<h1 id="api-graphnet-models-graphs-graph-definition--page-root">graph_definition<a class="headerlink" href="#api-graphnet-models-graphs-graph-definition--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.graphs.graph_definition"> +<span id="graph-definition"></span><h1 id="api-graphnet-models-graphs-graph-definition--page-root">graph_definition<a class="headerlink" href="#api-graphnet-models-graphs-graph-definition--page-root" title="Link to this heading">¶</a></h1> +<p>Modules for defining graphs.</p> +<p>These are self-contained graph definitions that hold all the graph-altering +code in graphnet. These modules define what the GNNs sees as input and can be +passed to dataloaders during training and deployment.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.graphs.graph_definition.GraphDefinition"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.graph_definition.</span></span><span class="sig-name descname"><span class="pre">GraphDefinition</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">detector</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/graph_definition.html#GraphDefinition"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.graph_definition.GraphDefinition" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>An Abstract class to create graph definitions from.</p> +<p>Construct ´GraphDefinition´. The ´detector´ holds.</p> +<p>´Detector´-specific code. E.g. scaling/standardization and geometry +tables.</p> +<p>´node_definition´ defines the nodes in the graph.</p> +<p>´edge_definition´ defines the connectivity of the nodes in the graph.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>detector</strong> (<a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a>) – The corresponding ´Detector´ representing the data.</p></li> +<li><p><strong>node_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition" title="graphnet.models.graphs.nodes.nodes.NodeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a>) – Definition of nodes.</p></li> +<li><p><strong>edge_definition</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<a class="reference internal" href="graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition" title="graphnet.models.graphs.edges.edges.EdgeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">EdgeDefinition</span></code></a>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Definition of edges. Defaults to None.</p></li> +<li><p><strong>node_feature_names</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Names of node feature columns. Defaults to None</p></li> +<li><p><strong>dtype</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>], default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – data type used for node features. e.g. ´torch.float´</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.graphs.graph_definition.GraphDefinition.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">node_features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_dicts</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">custom_label_functions</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data_path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/graph_definition.html#GraphDefinition.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.graph_definition.GraphDefinition.forward" title="Link to this definition">¶</a></dt> +<dd><p>Construct graph as ´Data´ object.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>node_features</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">ndarray</span></code>) – node features for graph. Shape ´[num_nodes, d]´</p></li> +<li><p><strong>node_feature_names</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – name of each column. Shape ´[,d]´.</p></li> +<li><p><strong>truth_dicts</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Dictionary containing truth labels.</p></li> +<li><p><strong>custom_label_functions</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">...</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Custom label functions. See <a class="reference external" href="https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels">https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels</a>.</p></li> +<li><p><strong>loss_weight_column</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of column that holds loss weight. Defaults to None.</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Loss weight associated with event. Defaults to None.</p></li> +<li><p><strong>loss_weight_default_value</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.</p></li> +<li><p><strong>data_path</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Path to dataset data files. Defaults to None.</p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code></p> +</dd> +<dt class="field-odd">Returns<span class="colon">:</span></dt> +<dd class="field-odd"><p>graph</p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -526,7 +613,7 @@ <h1 id="api-graphnet-models-graphs-graph-definition--page-root">graph_definition </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.graphs.graphs.html b/api/graphnet.models.graphs.graphs.html index ca229295b..0d889a033 100644 --- a/api/graphnet.models.graphs.graphs.html +++ b/api/graphnet.models.graphs.graphs.html @@ -378,11 +378,25 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-graphs-graphs--page-root" class="md-nav__link">graphs</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.graphs.KNNGraph" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNGraph</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.graphs.KNNGraph" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNGraph</span></code></a> + </li></ul> + </li></ul> </li> @@ -465,7 +479,14 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-graphs-graphs--page-root" class="md-nav__link">graphs</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.graphs.KNNGraph" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">KNNGraph</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -475,8 +496,31 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="graphs"> -<h1 id="api-graphnet-models-graphs-graphs--page-root">graphs<a class="headerlink" href="#api-graphnet-models-graphs-graphs--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.graphs.graphs"> +<span id="graphs"></span><h1 id="api-graphnet-models-graphs-graphs--page-root">graphs<a class="headerlink" href="#api-graphnet-models-graphs-graphs--page-root" title="Link to this heading">¶</a></h1> +<p>A module containing different graph representations in GraphNeT.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.graphs.graphs.KNNGraph"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.graphs.</span></span><span class="sig-name descname"><span class="pre">KNNGraph</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">detector</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nb_nearest_neighbours</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/graphs.html#KNNGraph"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.graphs.KNNGraph" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a></p> +<p>A Graph representation where Edges are drawn to nearest neighbours.</p> +<p>Construct k-nn graph representation.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>detector</strong> (<a class="reference internal" href="graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector" title="graphnet.models.detector.detector.Detector"><code class="xref py py-class docutils literal notranslate"><span class="pre">Detector</span></code></a>) – Detector that represents your data.</p></li> +<li><p><strong>node_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition" title="graphnet.models.graphs.nodes.nodes.NodeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a>) – Definition of nodes in the graph.</p></li> +<li><p><strong>node_feature_names</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of node features.</p></li> +<li><p><strong>dtype</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">dtype</span></code>], default: <code class="docutils literal notranslate"><span class="pre">torch.float32</span></code>) – data type for node features.</p></li> +<li><p><strong>nb_nearest_neighbours</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">8</span></code>) – Number of edges for each node. Defaults to 8.</p></li> +<li><p><strong>columns</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>], default: <code class="docutils literal notranslate"><span class="pre">[0,</span> <span class="pre">1,</span> <span class="pre">2]</span></code>) – node feature columns used for distance calculation</p></li> +<li><p><strong>[</strong><strong>0</strong> (<em>. Defaults to</em>) – </p></li> +<li><p><strong>1</strong> – </p></li> +<li><p><strong>2</strong><strong>]</strong><strong>.</strong> – </p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -526,7 +570,7 @@ <h1 id="api-graphnet-models-graphs-graphs--page-root">graphs<a class="headerlink </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.graphs.html b/api/graphnet.models.graphs.html index fef0be82e..fbbc2ff3e 100644 --- a/api/graphnet.models.graphs.html +++ b/api/graphnet.models.graphs.html @@ -474,8 +474,12 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="graphs"> -<h1 id="api-graphnet-models-graphs--page-root">graphs<a class="headerlink" href="#api-graphnet-models-graphs--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.graphs"> +<span id="graphs"></span><h1 id="api-graphnet-models-graphs--page-root">graphs<a class="headerlink" href="#api-graphnet-models-graphs--page-root" title="Link to this heading">¶</a></h1> +<p>Modules for constructing graphs.</p> +<p>´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features.</p> <p><h2> Subpackages </h2></p> <div class="toctree-wrapper compound"> <ul> @@ -492,8 +496,14 @@ <h1 id="api-graphnet-models-graphs--page-root">graphs<a class="headerlink" href= <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.graph_definition.html">graph_definition</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.graphs.html">graphs</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.graph_definition.html">graph_definition</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition"><code class="docutils literal notranslate"><span class="pre">GraphDefinition</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.graphs.html">graphs</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.graphs.html#graphnet.models.graphs.graphs.KNNGraph"><code class="docutils literal notranslate"><span class="pre">KNNGraph</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -545,7 +555,7 @@ <h1 id="api-graphnet-models-graphs--page-root">graphs<a class="headerlink" href= </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.graphs.nodes.html b/api/graphnet.models.graphs.nodes.html index 1609b8099..2e99388bb 100644 --- a/api/graphnet.models.graphs.nodes.html +++ b/api/graphnet.models.graphs.nodes.html @@ -482,12 +482,20 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="nodes"> -<h1 id="api-graphnet-models-graphs-nodes--page-root">nodes<a class="headerlink" href="#api-graphnet-models-graphs-nodes--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.graphs.nodes"> +<span id="nodes"></span><h1 id="api-graphnet-models-graphs-nodes--page-root">nodes<a class="headerlink" href="#api-graphnet-models-graphs-nodes--page-root" title="Link to this heading">¶</a></h1> +<p>Modules for constructing graphs.</p> +<p>´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html">nodes</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html">nodes</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition"><code class="docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodesAsPulses"><code class="docutils literal notranslate"><span class="pre">NodesAsPulses</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -539,7 +547,7 @@ <h1 id="api-graphnet-models-graphs-nodes--page-root">nodes<a class="headerlink" </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.graphs.nodes.nodes.html b/api/graphnet.models.graphs.nodes.nodes.html index 23b5575bd..a5ac7ba2d 100644 --- a/api/graphnet.models.graphs.nodes.nodes.html +++ b/api/graphnet.models.graphs.nodes.nodes.html @@ -370,11 +370,63 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-graphs-nodes-nodes--page-root" class="md-nav__link">nodes</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">set_number_of_inputs()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodesAsPulses" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodesAsPulses</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">set_number_of_inputs()</span></code></a> + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.graphs.nodes.nodes.NodesAsPulses" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodesAsPulses</span></code></a> + + + </li></ul> + </li></ul> </li> @@ -473,7 +525,24 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-graphs-nodes-nodes--page-root" class="md-nav__link">nodes</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_outputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">set_number_of_inputs()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.graphs.nodes.nodes.NodesAsPulses" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">NodesAsPulses</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -483,8 +552,65 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="nodes"> -<h1 id="api-graphnet-models-graphs-nodes-nodes--page-root">nodes<a class="headerlink" href="#api-graphnet-models-graphs-nodes-nodes--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.graphs.nodes.nodes"> +<span id="nodes"></span><h1 id="api-graphnet-models-graphs-nodes-nodes--page-root">nodes<a class="headerlink" href="#api-graphnet-models-graphs-nodes-nodes--page-root" title="Link to this heading">¶</a></h1> +<p>Class(es) for building/connecting graphs.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodeDefinition"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.nodes.nodes.</span></span><span class="sig-name descname"><span class="pre">NodeDefinition</span></span><a class="reference internal" href="../_modules/graphnet/models/graphs/nodes/nodes.html#NodeDefinition"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>Base class for graph building.</p> +<p>Construct <cite>Detector</cite>.</p> +<dl class="field-list simple"> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodeDefinition.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/nodes/nodes.html#NodeDefinition.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward" title="Link to this definition">¶</a></dt> +<dd><p>Construct nodes from raw node features.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>x</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">tensor</span></code>) – standardized node features with shape ´[num_pulses, d]´,</p></li> +<li><p><strong>features.</strong> (<em>where ´d´ is the number</em><em> of </em><em>node</em>) – </p></li> +</ul> +</dd> +<dt class="field-even">Returns<span class="colon">:</span></dt> +<dd class="field-even"><p>a graph without edges</p> +</dd> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p>graph</p> +</dd> +</dl> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_outputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs" title="Link to this definition">¶</a></dt> +<dd><p>Return number of output features.</p> +<p>This the default, but may be overridden by specific inheriting classes.</p> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs"> +<span class="sig-name descname"><span class="pre">set_number_of_inputs</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">node_feature_names</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/graphs/nodes/nodes.html#NodeDefinition.set_number_of_inputs"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs" title="Link to this definition">¶</a></dt> +<dd><p>Return number of inputs expected by node definition.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>node_feature_names</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]) – name of each node feature column.</p> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.graphs.nodes.nodes.NodesAsPulses"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.graphs.nodes.nodes.</span></span><span class="sig-name descname"><span class="pre">NodesAsPulses</span></span><a class="reference internal" href="../_modules/graphnet/models/graphs/nodes/nodes.html#NodesAsPulses"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.graphs.nodes.nodes.NodesAsPulses" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.graphs.nodes.nodes.NodeDefinition" title="graphnet.models.graphs.nodes.nodes.NodeDefinition"><code class="xref py py-class docutils literal notranslate"><span class="pre">NodeDefinition</span></code></a></p> +<p>Represent each measured pulse of Cherenkov Radiation as a node.</p> +<p>Construct <cite>Detector</cite>.</p> +<dl class="field-list simple"> +</dl> +</dd></dl> </section> @@ -534,7 +660,7 @@ <h1 id="api-graphnet-models-graphs-nodes-nodes--page-root">nodes<a class="header </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.html b/api/graphnet.models.html index fa38a95a6..d514e1468 100644 --- a/api/graphnet.models.html +++ b/api/graphnet.models.html @@ -445,8 +445,14 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="models"> -<h1 id="api-graphnet-models--page-root">models<a class="headerlink" href="#api-graphnet-models--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models"> +<span id="models"></span><h1 id="api-graphnet-models--page-root">models<a class="headerlink" href="#api-graphnet-models--page-root" title="Link to this heading">¶</a></h1> +<p>Modules for configuring and building models.</p> +<p><cite>graphnet.models</cite> allows for configuring and building complex GNN models using +simple, physics-oriented components. This module provides modular components +subclassing <cite>torch.nn.Module</cite>, meaning that users only need to import a few, +existing, purpose-built components and chain them together to form a complete +GNN</p> <p><h2> Subpackages </h2></p> <div class="toctree-wrapper compound"> <ul> @@ -487,10 +493,29 @@ <h1 id="api-graphnet-models--page-root">models<a class="headerlink" href="#api-g <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.coarsening.html">coarsening</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.model.html">model</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.standard_model.html">standard_model</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.utils.html">utils</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.coarsening.html">coarsening</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.unbatch_edge_index"><code class="docutils literal notranslate"><span class="pre">unbatch_edge_index()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening"><code class="docutils literal notranslate"><span class="pre">Coarsening</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.AttributeCoarsening"><code class="docutils literal notranslate"><span class="pre">AttributeCoarsening</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.DOMCoarsening"><code class="docutils literal notranslate"><span class="pre">DOMCoarsening</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.CustomDOMCoarsening"><code class="docutils literal notranslate"><span class="pre">CustomDOMCoarsening</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.coarsening.html#graphnet.models.coarsening.DOMAndTimeWindowCoarsening"><code class="docutils literal notranslate"><span class="pre">DOMAndTimeWindowCoarsening</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.model.html">model</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model"><code class="docutils literal notranslate"><span class="pre">Model</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.standard_model.html">standard_model</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel"><code class="docutils literal notranslate"><span class="pre">StandardModel</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.utils.html">utils</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.utils.html#graphnet.models.utils.calculate_xyzt_homophily"><code class="docutils literal notranslate"><span class="pre">calculate_xyzt_homophily()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.utils.html#graphnet.models.utils.calculate_distance_matrix"><code class="docutils literal notranslate"><span class="pre">calculate_distance_matrix()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.utils.html#graphnet.models.utils.knn_graph_batch"><code class="docutils literal notranslate"><span class="pre">knn_graph_batch()</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -542,7 +567,7 @@ <h1 id="api-graphnet-models--page-root">models<a class="headerlink" href="#api-g </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.model.html b/api/graphnet.models.model.html index f1bc235f8..286f23f20 100644 --- a/api/graphnet.models.model.html +++ b/api/graphnet.models.model.html @@ -372,11 +372,108 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-model--page-root" class="md-nav__link">model</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.model.Model" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Model</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.save" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.save_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_state_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.load_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_state_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Model</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.save" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.save_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_state_dict()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.load_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_state_dict()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.model.Model.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + + + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -436,7 +533,34 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-model--page-root" class="md-nav__link">model</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.model.Model" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Model</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.save" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.save_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_state_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.load_state_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load_state_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.model.Model.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -446,8 +570,183 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="model"> -<h1 id="api-graphnet-models-model--page-root">model<a class="headerlink" href="#api-graphnet-models-model--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.model"> +<span id="model"></span><h1 id="api-graphnet-models-model--page-root">model<a class="headerlink" href="#api-graphnet-models-model--page-root" title="Link to this heading">¶</a></h1> +<p>Base class(es) for building models.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.model.Model"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.model.</span></span><span class="sig-name descname"><span class="pre">Model</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">class_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">level</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">log_folder</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger" title="graphnet.utilities.logging.Logger"><code class="xref py py-class docutils literal notranslate"><span class="pre">Logger</span></code></a>, <a class="reference internal" href="graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable" title="graphnet.utilities.config.configurable.Configurable"><code class="xref py py-class docutils literal notranslate"><span class="pre">Configurable</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">LightningModule</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code></p> +<p>Base class for all models in graphnet.</p> +<p>Construct <cite>Logger</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>name</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>class_name</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>level</strong> (<em>int</em>) – </p></li> +<li><p><strong>log_folder</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.forward"> +<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.forward" title="Link to this definition">¶</a></dt> +<dd><p>Forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>x</strong> (<em>Tensor</em><em> | </em><em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.fit"> +<span class="sig-name descname"><span class="pre">fit</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">train_dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">val_dataloader</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">max_epochs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">callbacks</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ckpt_path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">logger</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">log_every_n_steps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gradient_clip_val</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">trainer_kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.fit"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.fit" title="Link to this definition">¶</a></dt> +<dd><p>Fit <cite>Model</cite> using <cite>pytorch_lightning.Trainer</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>train_dataloader</strong> (<em>DataLoader</em>) – </p></li> +<li><p><strong>val_dataloader</strong> (<em>DataLoader</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>max_epochs</strong> (<em>int</em>) – </p></li> +<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>callbacks</strong> (<em>List</em><em>[</em><em>Callback</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>ckpt_path</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>logger</strong> (<em>Logger</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>log_every_n_steps</strong> (<em>int</em>) – </p></li> +<li><p><strong>gradient_clip_val</strong> (<em>float</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>trainer_kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.predict"> +<span class="sig-name descname"><span class="pre">predict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.predict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.predict" title="Link to this definition">¶</a></dt> +<dd><p>Return predictions for <cite>dataloader</cite>.</p> +<p>Returns a list of Tensors, one for each model output.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li> +<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.predict_as_dataframe"> +<span class="sig-name descname"><span class="pre">predict_as_dataframe</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">additional_attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.predict_as_dataframe"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.predict_as_dataframe" title="Link to this definition">¶</a></dt> +<dd><p>Return predictions for <cite>dataloader</cite> as a DataFrame.</p> +<p>Include <cite>additional_attributes</cite> as additional columns in the output +DataFrame.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">DataFrame</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li> +<li><p><strong>prediction_columns</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>additional_attributes</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.save"> +<span class="sig-name descname"><span class="pre">save</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.save"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.save" title="Link to this definition">¶</a></dt> +<dd><p>Save entire model to <cite>path</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.load"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">load</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.load"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.load" title="Link to this definition">¶</a></dt> +<dd><p>Load entire model from <cite>path</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><a class="reference internal" href="#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.save_state_dict"> +<span class="sig-name descname"><span class="pre">save_state_dict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.save_state_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.save_state_dict" title="Link to this definition">¶</a></dt> +<dd><p>Save model <cite>state_dict</cite> to <cite>path</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.load_state_dict"> +<span class="sig-name descname"><span class="pre">load_state_dict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.load_state_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.load_state_dict" title="Link to this definition">¶</a></dt> +<dd><p>Load model <cite>state_dict</cite> from <cite>path</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><a class="reference internal" href="#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>path</strong> (<em>str</em><em> | </em><em>Dict</em>) – </p></li> +<li><p><strong>kargs</strong> (<em>Any</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.model.Model.from_config"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">from_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">source</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">trust</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">load_modules</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/model.html#Model.from_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.model.Model.from_config" title="Link to this definition">¶</a></dt> +<dd><p>Construct <cite>Model</cite> instance from <cite>source</cite> configuration.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>trust</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – Whether to trust the ModelConfig file enough to <cite>eval(…)</cite> +any lambda function expressions contained.</p></li> +<li><p><strong>load_modules</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – List of modules used in the definition of the model +which, as a consequence, need to be loaded into the global +namespace. Defaults to loading <cite>torch</cite>.</p></li> +<li><p><strong>source</strong> (<a class="reference internal" href="graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig" title="graphnet.utilities.config.model_config.ModelConfig"><em>ModelConfig</em></a><em> | </em><em>str</em>) – </p></li> +</ul> +</dd> +<dt class="field-even">Raises<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>ValueError</strong> – If the ModelConfig contains lambda functions but + <cite>trust = False</cite>.</p> +</dd> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><a class="reference internal" href="#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -497,7 +796,7 @@ <h1 id="api-graphnet-models-model--page-root">model<a class="headerlink" href="# </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.standard_model.html b/api/graphnet.models.standard_model.html index da580b529..b886a0725 100644 --- a/api/graphnet.models.standard_model.html +++ b/api/graphnet.models.standard_model.html @@ -379,11 +379,135 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-standard-model--page-root" class="md-nav__link">standard_model</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">StandardModel</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.configure_optimizers" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">configure_optimizers()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.shared_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">shared_step()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.training_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">training_step()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.validation_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">validation_step()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.train" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">StandardModel</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.configure_optimizers" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">configure_optimizers()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.shared_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">shared_step()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.training_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">training_step()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.validation_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">validation_step()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.train" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.standard_model.StandardModel.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a> + + + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -436,7 +560,40 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-standard-model--page-root" class="md-nav__link">standard_model</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">StandardModel</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.configure_optimizers" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">configure_optimizers()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.shared_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">shared_step()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.training_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">training_step()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.validation_step" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">validation_step()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.train" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.predict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.standard_model.StandardModel.predict_as_dataframe" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">predict_as_dataframe()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -446,8 +603,193 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="standard-model"> -<h1 id="api-graphnet-models-standard-model--page-root">standard_model<a class="headerlink" href="#api-graphnet-models-standard-model--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.standard_model"> +<span id="standard-model"></span><h1 id="api-graphnet-models-standard-model--page-root">standard_model<a class="headerlink" href="#api-graphnet-models-standard-model--page-root" title="Link to this heading">¶</a></h1> +<p>Standard model class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.standard_model.</span></span><span class="sig-name descname"><span class="pre">StandardModel</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gnn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tasks</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optimizer_class=<class</span> <span class="pre">'torch.optim.adam.Adam'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optimizer_kwargs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scheduler_class</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scheduler_kwargs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scheduler_config</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>Main class for standard models in graphnet.</p> +<p>This class chains together the different elements of a complete GNN-based +model (detector read-in, GNN architecture, and task-specific read-outs).</p> +<p>Construct <cite>StandardModel</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><em>GraphDefinition</em></a>) – </p></li> +<li><p><strong>gnn</strong> (<a class="reference internal" href="graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN" title="graphnet.models.gnn.gnn.GNN"><em>GNN</em></a>) – </p></li> +<li><p><strong>tasks</strong> (<a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><em>Task</em></a><em> | </em><em>List</em><em>[</em><a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><em>Task</em></a><em>]</em>) – </p></li> +<li><p><strong>optimizer_class</strong> (<em>type</em>) – </p></li> +<li><p><strong>optimizer_kwargs</strong> (<em>Dict</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>scheduler_class</strong> (<em>type</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>scheduler_kwargs</strong> (<em>Dict</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>scheduler_config</strong> (<em>Dict</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.target_labels"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">target_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.target_labels" title="Link to this definition">¶</a></dt> +<dd><p>Return target label.</p> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.prediction_labels"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">prediction_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.prediction_labels" title="Link to this definition">¶</a></dt> +<dd><p>Return prediction labels.</p> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.configure_optimizers"> +<span class="sig-name descname"><span class="pre">configure_optimizers</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.configure_optimizers"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.configure_optimizers" title="Link to this definition">¶</a></dt> +<dd><p>Configure the model’s optimizer(s).</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]</p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.forward" title="Link to this definition">¶</a></dt> +<dd><p>Forward pass, chaining model components.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>]]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>data</strong> (<em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.shared_step"> +<span class="sig-name descname"><span class="pre">shared_step</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_idx</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.shared_step"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.shared_step" title="Link to this definition">¶</a></dt> +<dd><p>Perform shared step.</p> +<p>Applies the forward pass and the following loss calculation, shared +between the training and validation step.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>batch</strong> (<em>Data</em>) – </p></li> +<li><p><strong>batch_idx</strong> (<em>int</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.training_step"> +<span class="sig-name descname"><span class="pre">training_step</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">train_batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_idx</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.training_step"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.training_step" title="Link to this definition">¶</a></dt> +<dd><p>Perform training step.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>train_batch</strong> (<em>Data</em>) – </p></li> +<li><p><strong>batch_idx</strong> (<em>int</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.validation_step"> +<span class="sig-name descname"><span class="pre">validation_step</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">val_batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_idx</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.validation_step"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.validation_step" title="Link to this definition">¶</a></dt> +<dd><p>Perform validation step.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>val_batch</strong> (<em>Data</em>) – </p></li> +<li><p><strong>batch_idx</strong> (<em>int</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.compute_loss"> +<span class="sig-name descname"><span class="pre">compute_loss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">verbose</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.compute_loss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.compute_loss" title="Link to this definition">¶</a></dt> +<dd><p>Compute and sum losses across tasks.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>preds</strong> (<em>Tensor</em>) – </p></li> +<li><p><strong>data</strong> (<em>Data</em>) – </p></li> +<li><p><strong>verbose</strong> (<em>bool</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.inference"> +<span class="sig-name descname"><span class="pre">inference</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.inference"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.inference" title="Link to this definition">¶</a></dt> +<dd><p>Activate inference mode.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.train"> +<span class="sig-name descname"><span class="pre">train</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">mode</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.train"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.train" title="Link to this definition">¶</a></dt> +<dd><p>Deactivate inference mode.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>mode</strong> (<em>bool</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.predict"> +<span class="sig-name descname"><span class="pre">predict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.predict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.predict" title="Link to this definition">¶</a></dt> +<dd><p>Return predictions for <cite>dataloader</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li> +<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.standard_model.StandardModel.predict_as_dataframe"> +<span class="sig-name descname"><span class="pre">predict_as_dataframe</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">additional_attributes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gpus</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution_strategy</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/standard_model.html#StandardModel.predict_as_dataframe"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.standard_model.StandardModel.predict_as_dataframe" title="Link to this definition">¶</a></dt> +<dd><p>Return predictions for <cite>dataloader</cite> as a DataFrame.</p> +<p>Include <cite>additional_attributes</cite> as additional columns in the output +DataFrame.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">DataFrame</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li> +<li><p><strong>prediction_columns</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>additional_attributes</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>gpus</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>distribution_strategy</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -497,7 +839,7 @@ <h1 id="api-graphnet-models-standard-model--page-root">standard_model<a class="h </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.task.classification.html b/api/graphnet.models.task.classification.html index 3fb005f08..50cf9ca15 100644 --- a/api/graphnet.models.task.classification.html +++ b/api/graphnet.models.task.classification.html @@ -364,11 +364,101 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-task-classification--page-root" class="md-nav__link">classification</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.MulticlassClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MulticlassClassificationTask</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTask</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTaskLogits</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.MulticlassClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MulticlassClassificationTask</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.BinaryClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTask</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTaskLogits</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -458,7 +548,34 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-task-classification--page-root" class="md-nav__link">classification</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.MulticlassClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MulticlassClassificationTask</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTask</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTaskLogits</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -468,8 +585,147 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="classification"> -<h1 id="api-graphnet-models-task-classification--page-root">classification<a class="headerlink" href="#api-graphnet-models-task-classification--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.task.classification"> +<span id="classification"></span><h1 id="api-graphnet-models-task-classification--page-root">classification<a class="headerlink" href="#api-graphnet-models-task-classification--page-root" title="Link to this heading">¶</a></h1> +<p>Classification-specific <cite>Model</cite> class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.MulticlassClassificationTask"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.classification.</span></span><span class="sig-name descname"><span class="pre">MulticlassClassificationTask</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/classification.html#MulticlassClassificationTask"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.classification.MulticlassClassificationTask" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask" title="graphnet.models.task.task.IdentityTask"><code class="xref py py-class docutils literal notranslate"><span class="pre">IdentityTask</span></code></a></p> +<p>General task for classifying any number of classes.</p> +<p>Requires the same number of input features as the number of classes being +predicted. Returns the untransformed latent features, which are interpreted +as the logits for each class being classified.</p> +<p>Construct IdentityTask.</p> +<p>Return the <cite>nb_outputs</cite> as a direct, affine transformation of the last +hidden layer.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nb_outputs</strong> (<em>int</em>) – </p></li> +<li><p><strong>target_labels</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>Any</em>) – </p></li> +<li><p><strong>args</strong> (<em>Any</em>) – </p></li> +<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTask"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.classification.</span></span><span class="sig-name descname"><span class="pre">BinaryClassificationTask</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/classification.html#BinaryClassificationTask"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTask" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Performs binary classification.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTask.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTask.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['target']</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['target_pred']</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTaskLogits"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.classification.</span></span><span class="sig-name descname"><span class="pre">BinaryClassificationTaskLogits</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/classification.html#BinaryClassificationTaskLogits"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTaskLogits" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Performs binary classification form logits.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['target']</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['target_pred']</span></em><a class="headerlink" href="#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> </section> @@ -519,7 +775,7 @@ <h1 id="api-graphnet-models-task-classification--page-root">classification<a cla </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.task.html b/api/graphnet.models.task.html index 1df19c8e8..0ed5f54b5 100644 --- a/api/graphnet.models.task.html +++ b/api/graphnet.models.task.html @@ -467,14 +467,38 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="task"> -<h1 id="api-graphnet-models-task--page-root">task<a class="headerlink" href="#api-graphnet-models-task--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.task"> +<span id="task"></span><h1 id="api-graphnet-models-task--page-root">task<a class="headerlink" href="#api-graphnet-models-task--page-root" title="Link to this heading">¶</a></h1> +<p>Physics task-specific modules to be used as model “read-outs”.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.classification.html">classification</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.reconstruction.html">reconstruction</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.task.html">task</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.classification.html">classification</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.classification.html#graphnet.models.task.classification.MulticlassClassificationTask"><code class="docutils literal notranslate"><span class="pre">MulticlassClassificationTask</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTask</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits"><code class="docutils literal notranslate"><span class="pre">BinaryClassificationTaskLogits</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.reconstruction.html">reconstruction</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstruction</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa"><code class="docutils literal notranslate"><span class="pre">DirectionReconstructionWithKappa</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction"><code class="docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa"><code class="docutils literal notranslate"><span class="pre">ZenithReconstructionWithKappa</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction"><code class="docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithPower</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithUncertainty</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction"><code class="docutils literal notranslate"><span class="pre">VertexReconstruction</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction"><code class="docutils literal notranslate"><span class="pre">PositionReconstruction</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction"><code class="docutils literal notranslate"><span class="pre">TimeReconstruction</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction"><code class="docutils literal notranslate"><span class="pre">InelasticityReconstruction</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.models.task.task.html">task</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task"><code class="docutils literal notranslate"><span class="pre">Task</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask"><code class="docutils literal notranslate"><span class="pre">IdentityTask</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -526,7 +550,7 @@ <h1 id="api-graphnet-models-task--page-root">task<a class="headerlink" href="#ap </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.task.reconstruction.html b/api/graphnet.models.task.reconstruction.html index 1d48b9eb0..06b772e83 100644 --- a/api/graphnet.models.task.reconstruction.html +++ b/api/graphnet.models.task.reconstruction.html @@ -371,11 +371,472 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-task-reconstruction--page-root" class="md-nav__link">reconstruction</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DirectionReconstructionWithKappa</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstructionWithKappa</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithPower</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithUncertainty</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VertexReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PositionReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TimeReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InelasticityReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.AzimuthReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstruction</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DirectionReconstructionWithKappa</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.ZenithReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstructionWithKappa</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithPower</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithUncertainty</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.VertexReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VertexReconstruction</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.PositionReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PositionReconstruction</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.TimeReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TimeReconstruction</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.InelasticityReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InelasticityReconstruction</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -458,7 +919,132 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-task-reconstruction--page-root" class="md-nav__link">reconstruction</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">AzimuthReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DirectionReconstructionWithKappa</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ZenithReconstructionWithKappa</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithPower</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EnergyReconstructionWithUncertainty</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VertexReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PositionReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TimeReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">InelasticityReconstruction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -468,8 +1054,706 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="reconstruction"> -<h1 id="api-graphnet-models-task-reconstruction--page-root">reconstruction<a class="headerlink" href="#api-graphnet-models-task-reconstruction--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.task.reconstruction"> +<span id="reconstruction"></span><h1 id="api-graphnet-models-task-reconstruction--page-root">reconstruction<a class="headerlink" href="#api-graphnet-models-task-reconstruction--page-root" title="Link to this heading">¶</a></h1> +<p>Reconstruction-specific <cite>Model</cite> class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">AzimuthReconstructionWithKappa</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#AzimuthReconstructionWithKappa"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs azimuthal angle and associated kappa (1/var).</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['azimuth']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['azimuth_pred',</span> <span class="pre">'azimuth_kappa']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">2</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstruction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">AzimuthReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#AzimuthReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstruction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa" title="graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa"><code class="xref py py-class docutils literal notranslate"><span class="pre">AzimuthReconstructionWithKappa</span></code></a></p> +<p>Reconstructs azimuthal angle.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['azimuth']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['azimuth_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">2</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.DirectionReconstructionWithKappa"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">DirectionReconstructionWithKappa</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#DirectionReconstructionWithKappa"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs direction with kappa from the 3D-vMF distribution.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['direction']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['dir_x_pred',</span> <span class="pre">'dir_y_pred',</span> <span class="pre">'dir_z_pred',</span> <span class="pre">'direction_kappa']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">3</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstruction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">ZenithReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#ZenithReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstruction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs zenith angle.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['zenith']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['zenith_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstructionWithKappa"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">ZenithReconstructionWithKappa</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#ZenithReconstructionWithKappa"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.task.reconstruction.ZenithReconstruction" title="graphnet.models.task.reconstruction.ZenithReconstruction"><code class="xref py py-class docutils literal notranslate"><span class="pre">ZenithReconstruction</span></code></a></p> +<p>Reconstructs zenith angle and associated kappa (1/var).</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['zenith']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['zenith_pred',</span> <span class="pre">'zenith_kappa']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">2</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstruction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">EnergyReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#EnergyReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstruction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs energy using stable method.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithPower"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">EnergyReconstructionWithPower</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#EnergyReconstructionWithPower"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs energy.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">EnergyReconstructionWithUncertainty</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#EnergyReconstructionWithUncertainty"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.task.reconstruction.EnergyReconstruction" title="graphnet.models.task.reconstruction.EnergyReconstruction"><code class="xref py py-class docutils literal notranslate"><span class="pre">EnergyReconstruction</span></code></a></p> +<p>Reconstructs energy and associated uncertainty (log(var)).</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['energy_pred',</span> <span class="pre">'energy_sigma']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">2</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.VertexReconstruction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">VertexReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#VertexReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.VertexReconstruction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs vertex position and time.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['vertex']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['position_x_pred',</span> <span class="pre">'position_y_pred',</span> <span class="pre">'position_z_pred',</span> <span class="pre">'interaction_time_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">4</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.PositionReconstruction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">PositionReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#PositionReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.PositionReconstruction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs vertex position.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['position']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['position_x_pred',</span> <span class="pre">'position_y_pred',</span> <span class="pre">'position_z_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">3</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.TimeReconstruction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">TimeReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#TimeReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.TimeReconstruction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs time.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['interaction_time']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['interaction_time_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.InelasticityReconstruction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.reconstruction.</span></span><span class="sig-name descname"><span class="pre">InelasticityReconstruction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/reconstruction.html#InelasticityReconstruction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.reconstruction.InelasticityReconstruction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.task.task.html#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Reconstructs interaction inelasticity.</p> +<p>That is, 1-(track energy / hadronic energy).</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels"> +<span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['inelasticity']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels"> +<span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">['inelasticity_pred']</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs"> +<span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">1</span></em><a class="headerlink" href="#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +</dd></dl> </section> @@ -519,7 +1803,7 @@ <h1 id="api-graphnet-models-task-reconstruction--page-root">reconstruction<a cla </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.task.task.html b/api/graphnet.models.task.task.html index 1d38d3bae..35d5c451e 100644 --- a/api/graphnet.models.task.task.html +++ b/api/graphnet.models.task.task.html @@ -378,11 +378,128 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-task-task--page-root" class="md-nav__link">task</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Task</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.train_eval" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train_eval()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IdentityTask</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.Task" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Task</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.Task.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.Task.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.Task.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.Task.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.Task.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.Task.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.Task.train_eval" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train_eval()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.IdentityTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IdentityTask</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.IdentityTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.IdentityTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.task.task.IdentityTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + + </li></ul> + </li></ul> </li> @@ -458,7 +575,40 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-task-task--page-root" class="md-nav__link">task</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Task</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.compute_loss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">compute_loss()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.inference" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">inference()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.Task.train_eval" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">train_eval()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">IdentityTask</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.default_target_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_target_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.default_prediction_labels" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">default_prediction_labels</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.task.task.IdentityTask.nb_inputs" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">nb_inputs</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -468,8 +618,154 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="task"> -<h1 id="api-graphnet-models-task-task--page-root">task<a class="headerlink" href="#api-graphnet-models-task-task--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.task.task"> +<span id="task"></span><h1 id="api-graphnet-models-task-task--page-root">task<a class="headerlink" href="#api-graphnet-models-task-task--page-root" title="Link to this heading">¶</a></h1> +<p>Base physics task-specific <cite>Model</cite> class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.task.Task"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.task.</span></span><span class="sig-name descname"><span class="pre">Task</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">hidden_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_function</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_labels</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_prediction_and_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_inference</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transform_support</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>Base class for all reconstruction and classification tasks.</p> +<p>Construct <cite>Task</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>hidden_size</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.</p></li> +<li><p><strong>loss_function</strong> (<a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a>) – Loss function appropriate to the task.</p></li> +<li><p><strong>target_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the <cite>Data</cite> object in +<cite>.compute_loss(…)</cite>.</p></li> +<li><p><strong>prediction_labels</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to <cite>target_label + _pred</cite>.</p></li> +<li><p><strong>transform_prediction_and_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.</p></li> +<li><p><strong>transform_target</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with <cite>transform_inference</cite> to perform the +inverse transform on the predicted quantity to recover the +physical scale.</p></li> +<li><p><strong>transform_inference</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with <cite>transform_target</cite>.</p></li> +<li><p><strong>transform_support</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +<cite>transform_target</cite> and <cite>transform_inference</cite> in case this is +restricted. By default the invertibility of <cite>transform_target</cite> +is tested on the range [-1e6, 1e6].</p></li> +<li><p><strong>loss_weight</strong> (<code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>], default: <code class="docutils literal notranslate"><span class="pre">None</span></code>) – Name of the attribute in <cite>data</cite> containing per-event +loss weights.</p></li> +</ul> +</dd> +</dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.task.task.Task.nb_inputs"> +<em class="property"><span class="pre">abstract</span><span class="w"> </span><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.task.task.Task.nb_inputs" title="Link to this definition">¶</a></dt> +<dd><p>Return number of inputs assumed by task.</p> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.task.task.Task.default_target_labels"> +<em class="property"><span class="pre">abstract</span><span class="w"> </span><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.task.task.Task.default_target_labels" title="Link to this definition">¶</a></dt> +<dd><p>Return default target labels.</p> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.task.task.Task.default_prediction_labels"> +<em class="property"><span class="pre">abstract</span><span class="w"> </span><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.task.task.Task.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd><p>Return default prediction labels.</p> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.task.task.Task.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task.forward" title="Link to this definition">¶</a></dt> +<dd><p>Forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Data</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>x</strong> (<em>Tensor</em><em> | </em><em>Data</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.task.task.Task.compute_loss"> +<span class="sig-name descname"><span class="pre">compute_loss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pred</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task.compute_loss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task.compute_loss" title="Link to this definition">¶</a></dt> +<dd><p>Compute loss of <cite>pred</cite> wrt.</p> +<p>target labels in <cite>data</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>pred</strong> (<em>Tensor</em><em> | </em><em>Data</em>) – </p></li> +<li><p><strong>data</strong> (<em>Data</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.task.task.Task.inference"> +<span class="sig-name descname"><span class="pre">inference</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task.inference"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task.inference" title="Link to this definition">¶</a></dt> +<dd><p>Activate inference mode.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.models.task.task.Task.train_eval"> +<span class="sig-name descname"><span class="pre">train_eval</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#Task.train_eval"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.Task.train_eval" title="Link to this definition">¶</a></dt> +<dd><p>Deactivate inference mode.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.models.task.task.IdentityTask"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.models.task.task.</span></span><span class="sig-name descname"><span class="pre">IdentityTask</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">nb_outputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target_labels</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/task/task.html#IdentityTask"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.task.task.IdentityTask" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.models.task.task.Task" title="graphnet.models.task.task.Task"><code class="xref py py-class docutils literal notranslate"><span class="pre">Task</span></code></a></p> +<p>Identity, or trivial, task.</p> +<p>Construct IdentityTask.</p> +<p>Return the <cite>nb_outputs</cite> as a direct, affine transformation of the last +hidden layer.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>nb_outputs</strong> (<em>int</em>) – </p></li> +<li><p><strong>target_labels</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>Any</em>) – </p></li> +<li><p><strong>args</strong> (<em>Any</em>) – </p></li> +<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.task.task.IdentityTask.default_target_labels"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default_target_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.task.task.IdentityTask.default_target_labels" title="Link to this definition">¶</a></dt> +<dd><p>Return default target labels.</p> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.task.task.IdentityTask.default_prediction_labels"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default_prediction_labels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#graphnet.models.task.task.IdentityTask.default_prediction_labels" title="Link to this definition">¶</a></dt> +<dd><p>Return default prediction labels.</p> +</dd></dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.models.task.task.IdentityTask.nb_inputs"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nb_inputs</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#graphnet.models.task.task.IdentityTask.nb_inputs" title="Link to this definition">¶</a></dt> +<dd><p>Return number of inputs assumed by task.</p> +</dd></dl> +</dd></dl> </section> @@ -519,7 +815,7 @@ <h1 id="api-graphnet-models-task-task--page-root">task<a class="headerlink" href </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.models.utils.html b/api/graphnet.models.utils.html index 6fa20aae4..475f25794 100644 --- a/api/graphnet.models.utils.html +++ b/api/graphnet.models.utils.html @@ -386,11 +386,43 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-utils--page-root" class="md-nav__link">utils</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.utils.calculate_xyzt_homophily" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_xyzt_homophily()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.utils.calculate_distance_matrix" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_distance_matrix()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.utils.knn_graph_batch" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">knn_graph_batch()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.models.utils.calculate_xyzt_homophily" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_xyzt_homophily()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.utils.calculate_distance_matrix" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_distance_matrix()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.models.utils.knn_graph_batch" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">knn_graph_batch()</span></code></a> + + + </li></ul> + </li></ul> </li> @@ -436,7 +468,18 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-models-utils--page-root" class="md-nav__link">utils</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.models.utils.calculate_xyzt_homophily" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_xyzt_homophily()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.utils.calculate_distance_matrix" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">calculate_distance_matrix()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.models.utils.knn_graph_batch" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">knn_graph_batch()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -446,8 +489,69 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="utils"> -<h1 id="api-graphnet-models-utils--page-root">utils<a class="headerlink" href="#api-graphnet-models-utils--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.models.utils"> +<span id="utils"></span><h1 id="api-graphnet-models-utils--page-root">utils<a class="headerlink" href="#api-graphnet-models-utils--page-root" title="Link to this heading">¶</a></h1> +<p>Utility functions for <cite>graphnet.models</cite>.</p> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.utils.calculate_xyzt_homophily"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.utils.</span></span><span class="sig-name descname"><span class="pre">calculate_xyzt_homophily</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">edge_index</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/utils.html#calculate_xyzt_homophily"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.utils.calculate_xyzt_homophily" title="Link to this definition">¶</a></dt> +<dd><p>Calculate xyzt-homophily from a batch of graphs.</p> +<p>Homophily is a graph scalar quantity that measures the likeness of +variables in nodes. Notice that this calculator assumes a special order of +input features in x.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>]</p> +</dd> +<dt class="field-even">Returns<span class="colon">:</span></dt> +<dd class="field-even"><p>Tuple, each element with shape [batch_size,1].</p> +</dd> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>x</strong> (<em>Tensor</em>) – </p></li> +<li><p><strong>edge_index</strong> (<em>LongTensor</em>) – </p></li> +<li><p><strong>batch</strong> (<em>Batch</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.utils.calculate_distance_matrix"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.utils.</span></span><span class="sig-name descname"><span class="pre">calculate_distance_matrix</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">xyz_coords</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/utils.html#calculate_distance_matrix"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.utils.calculate_distance_matrix" title="Link to this definition">¶</a></dt> +<dd><p>Calculate the matrix of pairwise distances between pulses.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>xyz_coords</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – (x,y,z)-coordinates of pulses, of shape [nb_doms, 3].</p> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-odd">Returns<span class="colon">:</span></dt> +<dd class="field-odd"><p>Matrix of pairwise distances, of shape [nb_doms, nb_doms]</p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.models.utils.knn_graph_batch"> +<span class="sig-prename descclassname"><span class="pre">graphnet.models.utils.</span></span><span class="sig-name descname"><span class="pre">knn_graph_batch</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">batch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">k</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">columns</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/models/utils.html#knn_graph_batch"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.models.utils.knn_graph_batch" title="Link to this definition">¶</a></dt> +<dd><p>Calculate k-nearest-neighbours with individual k for each batch event.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>batch</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code>) – Batch of events.</p></li> +<li><p><strong>k</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]) – A list of k’s.</p></li> +<li><p><strong>columns</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]) – The columns of Data.x used for computing the distances. E.g., +Data.x[:,[0,1,2]]</p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code></p> +</dd> +<dt class="field-odd">Returns<span class="colon">:</span></dt> +<dd class="field-odd"><p>Returns the same batch of events, but with updated edges.</p> +</dd> +</dl> +</dd></dl> </section> @@ -497,7 +601,7 @@ <h1 id="api-graphnet-models-utils--page-root">utils<a class="headerlink" href="# </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.pisa.fitting.html b/api/graphnet.pisa.fitting.html index 71412f39f..8bc401462 100644 --- a/api/graphnet.pisa.fitting.html +++ b/api/graphnet.pisa.fitting.html @@ -659,7 +659,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.pisa.html b/api/graphnet.pisa.html index 610103161..77788eee8 100644 --- a/api/graphnet.pisa.html +++ b/api/graphnet.pisa.html @@ -465,7 +465,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.pisa.plotting.html b/api/graphnet.pisa.plotting.html index 07ab22cb9..fd463920f 100644 --- a/api/graphnet.pisa.plotting.html +++ b/api/graphnet.pisa.plotting.html @@ -573,7 +573,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.training.callbacks.html b/api/graphnet.training.callbacks.html index 62f336aab..2d3ffa1f3 100644 --- a/api/graphnet.training.callbacks.html +++ b/api/graphnet.training.callbacks.html @@ -344,11 +344,110 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-training-callbacks--page-root" class="md-nav__link">callbacks</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.PiecewiseLinearLR" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PiecewiseLinearLR</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.PiecewiseLinearLR.get_lr" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_lr()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ProgressBar</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_validation_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_validation_tqdm()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_predict_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_predict_tqdm()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_test_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_test_tqdm()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_train_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_train_tqdm()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.get_metrics" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_metrics()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_start" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_start()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_end" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_end()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.PiecewiseLinearLR" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PiecewiseLinearLR</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.PiecewiseLinearLR.get_lr" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_lr()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.ProgressBar" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ProgressBar</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.ProgressBar.init_validation_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_validation_tqdm()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.ProgressBar.init_predict_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_predict_tqdm()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.ProgressBar.init_test_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_test_tqdm()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.ProgressBar.init_train_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_train_tqdm()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.ProgressBar.get_metrics" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_metrics()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_start" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_start()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_end" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_end()</span></code></a> + + + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -408,7 +507,36 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-training-callbacks--page-root" class="md-nav__link">callbacks</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.PiecewiseLinearLR" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">PiecewiseLinearLR</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.PiecewiseLinearLR.get_lr" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_lr()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ProgressBar</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_validation_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_validation_tqdm()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_predict_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_predict_tqdm()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_test_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_test_tqdm()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.init_train_tqdm" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">init_train_tqdm()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.get_metrics" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_metrics()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_start" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_start()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_end" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">on_train_epoch_end()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -418,8 +546,151 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="callbacks"> -<h1 id="api-graphnet-training-callbacks--page-root">callbacks<a class="headerlink" href="#api-graphnet-training-callbacks--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.training.callbacks"> +<span id="callbacks"></span><h1 id="api-graphnet-training-callbacks--page-root">callbacks<a class="headerlink" href="#api-graphnet-training-callbacks--page-root" title="Link to this heading">¶</a></h1> +<p>Callback class(es) for using during model training.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.PiecewiseLinearLR"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.callbacks.</span></span><span class="sig-name descname"><span class="pre">PiecewiseLinearLR</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">optimizer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">milestones</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">factors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">last_epoch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">verbose</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#PiecewiseLinearLR"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.PiecewiseLinearLR" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">_LRScheduler</span></code></p> +<p>Interpolate learning rate linearly between milestones.</p> +<p>Construct <cite>PiecewiseLinearLR</cite>.</p> +<p>For each milestone, denoting a specified number of steps, a factor +multiplying the base learning rate is specified. For steps between two +milestones, the learning rate is interpolated linearly between the two +closest milestones. For steps before the first milestone, the factor +for the first milestone is used; vice versa for steps after the last +milestone.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>optimizer</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Optimizer</span></code>) – Wrapped optimizer.</p></li> +<li><p><strong>milestones</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>]) – List of step indices. Must be increasing.</p></li> +<li><p><strong>factors</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>]) – List of multiplicative factors. Must be same length as +<cite>milestones</cite>.</p></li> +<li><p><strong>last_epoch</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code>, default: <code class="docutils literal notranslate"><span class="pre">-1</span></code>) – The index of the last epoch.</p></li> +<li><p><strong>verbose</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, prints a message to stdout for each update.</p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.PiecewiseLinearLR.get_lr"> +<span class="sig-name descname"><span class="pre">get_lr</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#PiecewiseLinearLR.get_lr"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.PiecewiseLinearLR.get_lr" title="Link to this definition">¶</a></dt> +<dd><p>Get effective learning rate(s) for each optimizer.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code>]</p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.callbacks.</span></span><span class="sig-name descname"><span class="pre">ProgressBar</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">refresh_rate</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">process_position</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">TQDMProgressBar</span></code></p> +<p>Custom progress bar for graphnet.</p> +<p>Customises the default progress in pytorch-lightning.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>refresh_rate</strong> (<em>int</em>) – </p></li> +<li><p><strong>process_position</strong> (<em>int</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.init_validation_tqdm"> +<span class="sig-name descname"><span class="pre">init_validation_tqdm</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.init_validation_tqdm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.init_validation_tqdm" title="Link to this definition">¶</a></dt> +<dd><p>Override for customisation.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Bar</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.init_predict_tqdm"> +<span class="sig-name descname"><span class="pre">init_predict_tqdm</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.init_predict_tqdm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.init_predict_tqdm" title="Link to this definition">¶</a></dt> +<dd><p>Override for customisation.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Bar</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.init_test_tqdm"> +<span class="sig-name descname"><span class="pre">init_test_tqdm</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.init_test_tqdm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.init_test_tqdm" title="Link to this definition">¶</a></dt> +<dd><p>Override for customisation.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Bar</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.init_train_tqdm"> +<span class="sig-name descname"><span class="pre">init_train_tqdm</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.init_train_tqdm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.init_train_tqdm" title="Link to this definition">¶</a></dt> +<dd><p>Override for customisation.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Bar</span></code></p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.get_metrics"> +<span class="sig-name descname"><span class="pre">get_metrics</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">trainer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.get_metrics"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.get_metrics" title="Link to this definition">¶</a></dt> +<dd><p>Override to not show the version number in the logging.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>trainer</strong> (<em>Trainer</em>) – </p></li> +<li><p><strong>model</strong> (<em>LightningModule</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.on_train_epoch_start"> +<span class="sig-name descname"><span class="pre">on_train_epoch_start</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">trainer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.on_train_epoch_start"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_start" title="Link to this definition">¶</a></dt> +<dd><p>Print the results of the previous epoch on a separate line.</p> +<p>This allows the user to see the losses/metrics for previous epochs +while the current is training. The default behaviour in pytorch- +lightning is to overwrite the progress bar from previous epochs.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>trainer</strong> (<em>Trainer</em>) – </p></li> +<li><p><strong>model</strong> (<em>LightningModule</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.callbacks.ProgressBar.on_train_epoch_end"> +<span class="sig-name descname"><span class="pre">on_train_epoch_end</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">trainer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/callbacks.html#ProgressBar.on_train_epoch_end"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.callbacks.ProgressBar.on_train_epoch_end" title="Link to this definition">¶</a></dt> +<dd><p>Log the final progress bar for the epoch to file.</p> +<p>Don’t duplciate to stdout.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>trainer</strong> (<em>Trainer</em>) – </p></li> +<li><p><strong>model</strong> (<em>LightningModule</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -469,7 +740,7 @@ <h1 id="api-graphnet-training-callbacks--page-root">callbacks<a class="headerlin </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.training.html b/api/graphnet.training.html index 3a7e0dc99..c9ac2b308 100644 --- a/api/graphnet.training.html +++ b/api/graphnet.training.html @@ -424,10 +424,38 @@ <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.training.callbacks.html">callbacks</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.training.labels.html">labels</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.training.loss_functions.html">loss_functions</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.training.utils.html">utils</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.training.callbacks.html">callbacks</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR"><code class="docutils literal notranslate"><span class="pre">PiecewiseLinearLR</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar"><code class="docutils literal notranslate"><span class="pre">ProgressBar</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.training.labels.html">labels</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.labels.html#graphnet.training.labels.Label"><code class="docutils literal notranslate"><span class="pre">Label</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.labels.html#graphnet.training.labels.Direction"><code class="docutils literal notranslate"><span class="pre">Direction</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.training.loss_functions.html">loss_functions</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction"><code class="docutils literal notranslate"><span class="pre">LossFunction</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.MSELoss"><code class="docutils literal notranslate"><span class="pre">MSELoss</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.RMSELoss"><code class="docutils literal notranslate"><span class="pre">RMSELoss</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCoshLoss"><code class="docutils literal notranslate"><span class="pre">LogCoshLoss</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.CrossEntropyLoss"><code class="docutils literal notranslate"><span class="pre">CrossEntropyLoss</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.BinaryCrossEntropyLoss"><code class="docutils literal notranslate"><span class="pre">BinaryCrossEntropyLoss</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK"><code class="docutils literal notranslate"><span class="pre">LogCMK</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss"><code class="docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher2DLoss"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher2DLoss</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.EuclideanDistanceLoss"><code class="docutils literal notranslate"><span class="pre">EuclideanDistanceLoss</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher3DLoss"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher3DLoss</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.training.utils.html">utils</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.collate_fn"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.make_dataloader"><code class="docutils literal notranslate"><span class="pre">make_dataloader()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.make_train_validation_dataloader"><code class="docutils literal notranslate"><span class="pre">make_train_validation_dataloader()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.get_predictions"><code class="docutils literal notranslate"><span class="pre">get_predictions()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.training.utils.html#graphnet.training.utils.save_results"><code class="docutils literal notranslate"><span class="pre">save_results()</span></code></a></li> +</ul> +</li> <li class="toctree-l1"><a class="reference internal" href="graphnet.training.weight_fitting.html">weight_fitting</a><ul> <li class="toctree-l2"><a class="reference internal" href="graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.WeightFitter"><code class="docutils literal notranslate"><span class="pre">WeightFitter</span></code></a></li> <li class="toctree-l2"><a class="reference internal" href="graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.Uniform"><code class="docutils literal notranslate"><span class="pre">Uniform</span></code></a></li> @@ -485,7 +513,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.training.labels.html b/api/graphnet.training.labels.html index cb464052f..6514d41a8 100644 --- a/api/graphnet.training.labels.html +++ b/api/graphnet.training.labels.html @@ -351,11 +351,45 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-training-labels--page-root" class="md-nav__link">labels</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.labels.Label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Label</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.labels.Label.key" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">key</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.labels.Direction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Direction</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.labels.Label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Label</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.labels.Label.key" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">key</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.labels.Direction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Direction</span></code></a> + </li></ul> + </li> <li class="md-nav__item"> @@ -408,7 +442,20 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-training-labels--page-root" class="md-nav__link">labels</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.labels.Label" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Label</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.labels.Label.key" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">key</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.labels.Direction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Direction</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -418,8 +465,48 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="labels"> -<h1 id="api-graphnet-training-labels--page-root">labels<a class="headerlink" href="#api-graphnet-training-labels--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.training.labels"> +<span id="labels"></span><h1 id="api-graphnet-training-labels--page-root">labels<a class="headerlink" href="#api-graphnet-training-labels--page-root" title="Link to this heading">¶</a></h1> +<p>Class(es) for constructing training labels at runtime.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.labels.Label"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.labels.</span></span><span class="sig-name descname"><span class="pre">Label</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">key</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/labels.html#Label"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.labels.Label" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code>, <a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger" title="graphnet.utilities.logging.Logger"><code class="xref py py-class docutils literal notranslate"><span class="pre">Logger</span></code></a></p> +<p>Base <cite>Label</cite> class for producing labels from single <cite>Data</cite> instance.</p> +<p>Construct <cite>Label</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>key</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>) – The name of the field in <cite>Data</cite> where the label will be +stored. That is, <cite>graph[key] = label</cite>.</p> +</dd> +</dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.training.labels.Label.key"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">key</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">str</span></em><a class="headerlink" href="#graphnet.training.labels.Label.key" title="Link to this definition">¶</a></dt> +<dd><p>Return value of <cite>key</cite>.</p> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.labels.Direction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.labels.</span></span><span class="sig-name descname"><span class="pre">Direction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">azimuth_key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">zenith_key</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/labels.html#Direction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.labels.Direction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.labels.Label" title="graphnet.training.labels.Label"><code class="xref py py-class docutils literal notranslate"><span class="pre">Label</span></code></a></p> +<p>Class for producing particle direction/pointing label.</p> +<p>Construct <cite>Direction</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>key</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'direction'</span></code>) – The name of the field in <cite>Data</cite> where the label will be +stored. That is, <cite>graph[key] = label</cite>.</p></li> +<li><p><strong>azimuth_key</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'azimuth'</span></code>) – The name of the pre-existing key in <cite>graph</cite> that will +be used to access the azimiuth angle, used when calculating +the direction.</p></li> +<li><p><strong>zenith_key</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, default: <code class="docutils literal notranslate"><span class="pre">'zenith'</span></code>) – The name of the pre-existing key in <cite>graph</cite> that will +be used to access the zenith angle, used when calculating the +direction.</p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -469,7 +556,7 @@ <h1 id="api-graphnet-training-labels--page-root">labels<a class="headerlink" hre </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.training.loss_functions.html b/api/graphnet.training.loss_functions.html index 61edd4160..feb15b1bf 100644 --- a/api/graphnet.training.loss_functions.html +++ b/api/graphnet.training.loss_functions.html @@ -358,11 +358,175 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-training-loss-functions--page-root" class="md-nav__link">loss_functions</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LossFunction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LossFunction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LossFunction.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.MSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MSELoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.RMSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RMSELoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCoshLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCoshLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.CrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CrossEntropyLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.BinaryCrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryCrossEntropyLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCMK</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK.backward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">backward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_exact()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_approx()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisher2DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher2DLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.EuclideanDistanceLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanDistanceLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisher3DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher3DLoss</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.LossFunction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LossFunction</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.LossFunction.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.MSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MSELoss</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.RMSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RMSELoss</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.LogCoshLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCoshLoss</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.CrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CrossEntropyLoss</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.BinaryCrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryCrossEntropyLoss</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.LogCMK" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCMK</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.LogCMK.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.LogCMK.backward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">backward()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.VonMisesFisherLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_exact()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_approx()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk()</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.VonMisesFisher2DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher2DLoss</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.EuclideanDistanceLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanDistanceLoss</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.loss_functions.VonMisesFisher3DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher3DLoss</span></code></a> + + + </li></ul> + </li> <li class="md-nav__item"> @@ -408,7 +572,52 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-training-loss-functions--page-root" class="md-nav__link">loss_functions</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LossFunction" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LossFunction</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LossFunction.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.MSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">MSELoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.RMSELoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">RMSELoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCoshLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCoshLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.CrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">CrossEntropyLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.BinaryCrossEntropyLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BinaryCrossEntropyLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">LogCMK</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK.forward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">forward()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.LogCMK.backward" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">backward()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_exact()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk_approx()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">log_cmk()</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisher2DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher2DLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.EuclideanDistanceLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">EuclideanDistanceLoss</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.loss_functions.VonMisesFisher3DLoss" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">VonMisesFisher3DLoss</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -418,8 +627,281 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="loss-functions"> -<h1 id="api-graphnet-training-loss-functions--page-root">loss_functions<a class="headerlink" href="#api-graphnet-training-loss-functions--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.training.loss_functions"> +<span id="loss-functions"></span><h1 id="api-graphnet-training-loss-functions--page-root">loss_functions<a class="headerlink" href="#api-graphnet-training-loss-functions--page-root" title="Link to this heading">¶</a></h1> +<p>Collection of loss functions.</p> +<p>All loss functions inherit from <cite>LossFunction</cite> which ensures a common syntax, +handles per-event weights, etc.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.LossFunction"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">LossFunction</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LossFunction"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LossFunction" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><code class="xref py py-class docutils literal notranslate"><span class="pre">Model</span></code></a></p> +<p>Base class for loss functions in <cite>graphnet</cite>.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.LossFunction.forward"> +<span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">prediction</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">weights</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">return_elements</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LossFunction.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LossFunction.forward" title="Link to this definition">¶</a></dt> +<dd><p>Forward pass for all loss functions.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>prediction</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – Tensor containing predictions. Shape [N,P]</p></li> +<li><p><strong>target</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code>) – Tensor containing targets. Shape [N,T]</p></li> +<li><p><strong>return_elements</strong> (<code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code>, default: <code class="docutils literal notranslate"><span class="pre">False</span></code>) – Whether elementwise loss terms should be returned. +The alternative is to return the averaged loss across examples.</p></li> +<li><p><strong>weights</strong> (<em>Tensor</em><em> | </em><em>None</em>) – </p></li> +</ul> +</dd> +<dt class="field-even">Return type<span class="colon">:</span></dt> +<dd class="field-even"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-odd">Returns<span class="colon">:</span></dt> +<dd class="field-odd"><p>Loss, either averaged to a scalar (if <cite>return_elements = False</cite>) or +elementwise terms with shape [N,] (if <cite>return_elements = True</cite>).</p> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.MSELoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">MSELoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#MSELoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.MSELoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p> +<p>Mean squared error loss.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.RMSELoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">RMSELoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#RMSELoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.RMSELoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.MSELoss" title="graphnet.training.loss_functions.MSELoss"><code class="xref py py-class docutils literal notranslate"><span class="pre">MSELoss</span></code></a></p> +<p>Root mean squared error loss.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.LogCoshLoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">LogCoshLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LogCoshLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LogCoshLoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p> +<p>Log-cosh loss function.</p> +<p>Acts like x^2 for small x; and like <a href="#id1"><span class="problematic" id="id2">|x|</span></a> for large x.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.CrossEntropyLoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">CrossEntropyLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">options</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#CrossEntropyLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.CrossEntropyLoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p> +<p>Compute cross-entropy loss for classification tasks.</p> +<p>Predictions are an [N, num_class]-matrix of logits (i.e., non-softmax’ed +probabilities), and targets are an [N,1]-matrix with integer values in +(0, num_classes - 1).</p> +<p>Construct CrossEntropyLoss.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>options</strong> (<em>int</em><em> | </em><em>List</em><em>[</em><em>Any</em><em>] </em><em>| </em><em>Dict</em><em>[</em><em>Any</em><em>, </em><em>int</em><em>]</em>) – </p></li> +<li><p><strong>args</strong> (<em>Any</em>) – </p></li> +<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.BinaryCrossEntropyLoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">BinaryCrossEntropyLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#BinaryCrossEntropyLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.BinaryCrossEntropyLoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p> +<p>Compute binary cross entropy loss.</p> +<p>Predictions are vector probabilities (i.e., values between 0 and 1), and +targets should be 0 and 1.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.LogCMK"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">LogCMK</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LogCMK"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LogCMK" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Function</span></code></p> +<p>MIT License.</p> +<p>Copyright (c) 2019 Max Ryabinin</p> +<p>Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the “Software”), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions:</p> +<p>The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software.</p> +<p>THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +_____________________</p> +<p>From [<a class="reference external" href="https://github.com/mryab/vmf_loss/blob/master/losses.py">https://github.com/mryab/vmf_loss/blob/master/losses.py</a>] +Modified to use modified Bessel function instead of exponentially scaled ditto +(i.e. <cite>.ive</cite> -> <cite>.iv</cite>) as indiciated in [1812.04616] in spite of suggestion in +Sec. 8.2 of this paper. The change has been validated through comparison with +exact calculations for <cite>m=2</cite> and <cite>m=3</cite> and found to yield the correct results.</p> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.LogCMK.forward"> +<em class="property"><span class="pre">static</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">ctx</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LogCMK.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LogCMK.forward" title="Link to this definition">¶</a></dt> +<dd><p>Forward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>ctx</strong> (<em>Any</em>) – </p></li> +<li><p><strong>m</strong> (<em>int</em>) – </p></li> +<li><p><strong>kappa</strong> (<em>Tensor</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.LogCMK.backward"> +<em class="property"><span class="pre">static</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">ctx</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">grad_output</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#LogCMK.backward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.LogCMK.backward" title="Link to this definition">¶</a></dt> +<dd><p>Backward pass.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>ctx</strong> (<em>Any</em>) – </p></li> +<li><p><strong>grad_output</strong> (<em>Tensor</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisherLoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">VonMisesFisherLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisherLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisherLoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p> +<p>General class for calculating von Mises-Fisher loss.</p> +<p>Requires implementation for specific dimension <cite>m</cite> in which the target and +prediction vectors need to be prepared.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">log_cmk_exact</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisherLoss.log_cmk_exact"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact" title="Link to this definition">¶</a></dt> +<dd><p>Calculate $log C_{m}(k)$ term in von Mises-Fisher loss exactly.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>m</strong> (<em>int</em>) – </p></li> +<li><p><strong>kappa</strong> (<em>Tensor</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">log_cmk_approx</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisherLoss.log_cmk_approx"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx" title="Link to this definition">¶</a></dt> +<dd><p>Calculate $log C_{m}(k)$ term in von Mises-Fisher loss approx.</p> +<p>[<a class="reference external" href="https://arxiv.org/abs/1812.04616">https://arxiv.org/abs/1812.04616</a>] Sec. 8.2 with additional minus sign.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>m</strong> (<em>int</em>) – </p></li> +<li><p><strong>kappa</strong> (<em>Tensor</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">log_cmk</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kappa_switch</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisherLoss.log_cmk"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk" title="Link to this definition">¶</a></dt> +<dd><p>Calculate $log C_{m}(k)$ term in von Mises-Fisher loss.</p> +<p>Since <cite>log_cmk_exact</cite> is diverges for <cite>kappa</cite> >~ 700 (using float64 +precision), and since <cite>log_cmk_approx</cite> is unaccurate for small <cite>kappa</cite>, +this method automatically switches between the two at <cite>kappa_switch</cite>, +ensuring continuity at this point.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>m</strong> (<em>int</em>) – </p></li> +<li><p><strong>kappa</strong> (<em>Tensor</em>) – </p></li> +<li><p><strong>kappa_switch</strong> (<em>float</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisher2DLoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">VonMisesFisher2DLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisher2DLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisher2DLoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.VonMisesFisherLoss" title="graphnet.training.loss_functions.VonMisesFisherLoss"><code class="xref py py-class docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a></p> +<p>von Mises-Fisher loss function vectors in the 2D plane.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.EuclideanDistanceLoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">EuclideanDistanceLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#EuclideanDistanceLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.EuclideanDistanceLoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.LossFunction" title="graphnet.training.loss_functions.LossFunction"><code class="xref py py-class docutils literal notranslate"><span class="pre">LossFunction</span></code></a></p> +<p>Mean squared error in three dimensions.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.training.loss_functions.VonMisesFisher3DLoss"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.training.loss_functions.</span></span><span class="sig-name descname"><span class="pre">VonMisesFisher3DLoss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/loss_functions.html#VonMisesFisher3DLoss"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.loss_functions.VonMisesFisher3DLoss" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="#graphnet.training.loss_functions.VonMisesFisherLoss" title="graphnet.training.loss_functions.VonMisesFisherLoss"><code class="xref py py-class docutils literal notranslate"><span class="pre">VonMisesFisherLoss</span></code></a></p> +<p>von Mises-Fisher loss function vectors in the 3D plane.</p> +<p>Construct <cite>LossFunction</cite>, saving model config.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><p><strong>kwargs</strong> (<em>Any</em>) – </p> +</dd> +</dl> +</dd></dl> </section> @@ -469,7 +951,7 @@ <h1 id="api-graphnet-training-loss-functions--page-root">loss_functions<a class= </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.training.utils.html b/api/graphnet.training.utils.html index d29636c35..ac1d9919d 100644 --- a/api/graphnet.training.utils.html +++ b/api/graphnet.training.utils.html @@ -365,11 +365,61 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-training-utils--page-root" class="md-nav__link">utils</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.utils.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.utils.make_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_dataloader()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.utils.make_train_validation_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_train_validation_dataloader()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.utils.get_predictions" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_predictions()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.utils.save_results" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_results()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.training.utils.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.utils.make_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_dataloader()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.utils.make_train_validation_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_train_validation_dataloader()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.utils.get_predictions" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_predictions()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.training.utils.save_results" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_results()</span></code></a> + + + </li></ul> + </li> <li class="md-nav__item"> @@ -408,7 +458,22 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-training-utils--page-root" class="md-nav__link">utils</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.training.utils.collate_fn" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">collate_fn()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.utils.make_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_dataloader()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.utils.make_train_validation_dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">make_train_validation_dataloader()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.utils.get_predictions" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_predictions()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.training.utils.save_results" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_results()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -418,8 +483,128 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="utils"> -<h1 id="api-graphnet-training-utils--page-root">utils<a class="headerlink" href="#api-graphnet-training-utils--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.training.utils"> +<span id="utils"></span><h1 id="api-graphnet-training-utils--page-root">utils<a class="headerlink" href="#api-graphnet-training-utils--page-root" title="Link to this heading">¶</a></h1> +<p>Utility functions for <cite>graphnet.training</cite>.</p> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.training.utils.collate_fn"> +<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">collate_fn</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">graphs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#collate_fn"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.collate_fn" title="Link to this definition">¶</a></dt> +<dd><p>Remove graphs with less than two DOM hits.</p> +<p>Should not occur in “production.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Batch</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>graphs</strong> (<em>List</em><em>[</em><em>Data</em><em>]</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.training.utils.make_dataloader"> +<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">make_dataloader</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">db</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shuffle</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">persistent_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">labels</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#make_dataloader"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.make_dataloader" title="Link to this definition">¶</a></dt> +<dd><p>Construct <cite>DataLoader</cite> instance.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>db</strong> (<em>str</em>) – </p></li> +<li><p><strong>pulsemaps</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><em>GraphDefinition</em></a><em> | </em><em>None</em>) – </p></li> +<li><p><strong>features</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>truth</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>batch_size</strong> (<em>int</em>) – </p></li> +<li><p><strong>shuffle</strong> (<em>bool</em>) – </p></li> +<li><p><strong>selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>num_workers</strong> (<em>int</em>) – </p></li> +<li><p><strong>persistent_workers</strong> (<em>bool</em>) – </p></li> +<li><p><strong>node_truth</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>truth_table</strong> (<em>str</em>) – </p></li> +<li><p><strong>node_truth_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>string_selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>loss_weight_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>loss_weight_column</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>index_column</strong> (<em>str</em>) – </p></li> +<li><p><strong>labels</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Callable</em><em>] </em><em>| </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.training.utils.make_train_validation_dataloader"> +<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">make_train_validation_dataloader</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">db</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">database_indices</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">test_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">persistent_workers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">labels</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#make_train_validation_dataloader"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.make_train_validation_dataloader" title="Link to this definition">¶</a></dt> +<dd><p>Construct train and test <cite>DataLoader</cite> instances.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Tuple</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">DataLoader</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>db</strong> (<em>str</em>) – </p></li> +<li><p><strong>graph_definition</strong> (<a class="reference internal" href="graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition" title="graphnet.models.graphs.graph_definition.GraphDefinition"><em>GraphDefinition</em></a><em> | </em><em>None</em>) – </p></li> +<li><p><strong>selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>pulsemaps</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>features</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>truth</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>batch_size</strong> (<em>int</em>) – </p></li> +<li><p><strong>database_indices</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>seed</strong> (<em>int</em>) – </p></li> +<li><p><strong>test_size</strong> (<em>float</em>) – </p></li> +<li><p><strong>num_workers</strong> (<em>int</em>) – </p></li> +<li><p><strong>persistent_workers</strong> (<em>bool</em>) – </p></li> +<li><p><strong>node_truth</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>truth_table</strong> (<em>str</em>) – </p></li> +<li><p><strong>node_truth_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>string_selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>loss_weight_column</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>loss_weight_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>index_column</strong> (<em>str</em>) – </p></li> +<li><p><strong>labels</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Callable</em><em>] </em><em>| </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.training.utils.get_predictions"> +<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">get_predictions</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">trainer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">prediction_columns</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_level</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">additional_attributes</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#get_predictions"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.get_predictions" title="Link to this definition">¶</a></dt> +<dd><p>Get <cite>model</cite> predictions on <cite>dataloader</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">DataFrame</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>trainer</strong> (<em>Trainer</em>) – </p></li> +<li><p><strong>model</strong> (<a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><em>Model</em></a>) – </p></li> +<li><p><strong>dataloader</strong> (<em>DataLoader</em>) – </p></li> +<li><p><strong>prediction_columns</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>node_level</strong> (<em>bool</em>) – </p></li> +<li><p><strong>additional_attributes</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.training.utils.save_results"> +<span class="sig-prename descclassname"><span class="pre">graphnet.training.utils.</span></span><span class="sig-name descname"><span class="pre">save_results</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">db</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tag</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">results</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">archive</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">model</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/training/utils.html#save_results"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.training.utils.save_results" title="Link to this definition">¶</a></dt> +<dd><p>Save trained model and prediction <cite>results</cite> in <cite>db</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>db</strong> (<em>str</em>) – </p></li> +<li><p><strong>tag</strong> (<em>str</em>) – </p></li> +<li><p><strong>results</strong> (<em>DataFrame</em>) – </p></li> +<li><p><strong>archive</strong> (<em>str</em>) – </p></li> +<li><p><strong>model</strong> (<a class="reference internal" href="graphnet.models.model.html#graphnet.models.model.Model" title="graphnet.models.model.Model"><em>Model</em></a>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -469,7 +654,7 @@ <h1 id="api-graphnet-training-utils--page-root">utils<a class="headerlink" href= </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.training.weight_fitting.html b/api/graphnet.training.weight_fitting.html index aa7cedd58..78198cb33 100644 --- a/api/graphnet.training.weight_fitting.html +++ b/api/graphnet.training.weight_fitting.html @@ -611,7 +611,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.argparse.html b/api/graphnet.utilities.argparse.html index 60f4ad179..814d2bdf2 100644 --- a/api/graphnet.utilities.argparse.html +++ b/api/graphnet.utilities.argparse.html @@ -639,7 +639,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.config.base_config.html b/api/graphnet.utilities.config.base_config.html index dda866be3..48747c41d 100644 --- a/api/graphnet.utilities.config.base_config.html +++ b/api/graphnet.utilities.config.base_config.html @@ -357,11 +357,81 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-base-config--page-root" class="md-nav__link">base_config</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BaseConfig</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.dump" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dump()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.get_all_argument_values" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_argument_values()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.base_config.BaseConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BaseConfig</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.base_config.BaseConfig.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.base_config.BaseConfig.dump" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dump()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.base_config.BaseConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.base_config.BaseConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.base_config.BaseConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.base_config.get_all_argument_values" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_argument_values()</span></code></a> + + + </li></ul> + </li> <li class="md-nav__item"> @@ -465,7 +535,28 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-base-config--page-root" class="md-nav__link">base_config</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">BaseConfig</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.load" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">load()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.dump" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dump()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.BaseConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.base_config.get_all_argument_values" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_argument_values()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -475,8 +566,88 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="base-config"> -<h1 id="api-graphnet-utilities-config-base-config--page-root">base_config<a class="headerlink" href="#api-graphnet-utilities-config-base-config--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.utilities.config.base_config"> +<span id="base-config"></span><h1 id="api-graphnet-utilities-config-base-config--page-root">base_config<a class="headerlink" href="#api-graphnet-utilities-config-base-config--page-root" title="Link to this heading">¶</a></h1> +<p>Base config class(es).</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.base_config.</span></span><span class="sig-name descname"><span class="pre">BaseConfig</span></span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#BaseConfig"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">BaseModel</span></code></p> +<p>Base class for Configs.</p> +<p>Create a new model by parsing and validating input data from keyword arguments.</p> +<p>Raises [<cite>ValidationError</cite>][pydantic_core.ValidationError] if the input data cannot be +validated to form a valid model.</p> +<p><cite>__init__</cite> uses <cite>__pydantic_self__</cite> instead of the more common <cite>self</cite> for the first arg to +allow <cite>self</cite> as a field name.</p> +<dl class="field-list simple"> +</dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.load"> +<em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">load</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#BaseConfig.load"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.load" title="Link to this definition">¶</a></dt> +<dd><p>Load BaseConfig from <cite>path</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><a class="reference internal" href="#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.dump"> +<span class="sig-name descname"><span class="pre">dump</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#BaseConfig.dump"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.dump" title="Link to this definition">¶</a></dt> +<dd><p>Save BaseConfig to <cite>path</cite> as YAML file, or return as string.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>path</strong> (<em>str</em><em> | </em><em>None</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.as_dict"> +<span class="sig-name descname"><span class="pre">as_dict</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#BaseConfig.as_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.as_dict" title="Link to this definition">¶</a></dt> +<dd><p>Represent BaseConfig as a dict.</p> +<p>This builds on <cite>BaseModel.dict()</cite> but can be overwritten.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]</p> +</dd> +</dl> +</dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.model_config"> +<span class="sig-name descname"><span class="pre">model_config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[ConfigDict]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.model_config" title="Link to this definition">¶</a></dt> +<dd><p>Configuration for the model, should be a dictionary conforming to [<cite>ConfigDict</cite>][pydantic.config.ConfigDict].</p> +</dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.BaseConfig.model_fields"> +<span class="sig-name descname"><span class="pre">model_fields</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[dict[str,</span> <span class="pre">FieldInfo]]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.base_config.BaseConfig.model_fields" title="Link to this definition">¶</a></dt> +<dd><p>Metadata about the fields defined on the model, +mapping of field names to [<cite>FieldInfo</cite>][pydantic.fields.FieldInfo].</p> +<p>This replaces <cite>Model.__fields__</cite> from Pydantic V1.</p> +</dd></dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.base_config.get_all_argument_values"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.base_config.</span></span><span class="sig-name descname"><span class="pre">get_all_argument_values</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/base_config.html#get_all_argument_values"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.base_config.get_all_argument_values" title="Link to this definition">¶</a></dt> +<dd><p>Return dict of all argument values to <cite>fn</cite>, including defaults.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>fn</strong> (<em>Callable</em>) – </p></li> +<li><p><strong>args</strong> (<em>Any</em>) – </p></li> +<li><p><strong>kwargs</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> </section> @@ -526,7 +697,7 @@ <h1 id="api-graphnet-utilities-config-base-config--page-root">base_config<a clas </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.config.configurable.html b/api/graphnet.utilities.config.configurable.html index cc82e52e3..7654534eb 100644 --- a/api/graphnet.utilities.config.configurable.html +++ b/api/graphnet.utilities.config.configurable.html @@ -364,11 +364,54 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-configurable--page-root" class="md-nav__link">configurable</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Configurable</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.save_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_config()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.configurable.Configurable" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Configurable</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.configurable.Configurable.config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">config</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.configurable.Configurable.save_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_config()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.configurable.Configurable.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + </li></ul> + + </li></ul> + </li> <li class="md-nav__item"> @@ -465,7 +508,22 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-configurable--page-root" class="md-nav__link">configurable</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">Configurable</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.save_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_config()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.configurable.Configurable.from_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">from_config()</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -475,8 +533,49 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="configurable"> -<h1 id="api-graphnet-utilities-config-configurable--page-root">configurable<a class="headerlink" href="#api-graphnet-utilities-config-configurable--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.utilities.config.configurable"> +<span id="configurable"></span><h1 id="api-graphnet-utilities-config-configurable--page-root">configurable<a class="headerlink" href="#api-graphnet-utilities-config-configurable--page-root" title="Link to this heading">¶</a></h1> +<p>Bases for all configurable classes in <cite>graphnet</cite>.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.utilities.config.configurable.Configurable"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.configurable.</span></span><span class="sig-name descname"><span class="pre">Configurable</span></span><a class="reference internal" href="../_modules/graphnet/utilities/config/configurable.html#Configurable"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.configurable.Configurable" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">ABC</span></code></p> +<p>Base class for all configurable classes in graphnet.</p> +<p>Construct <cite>Configurable</cite>.</p> +<dl class="field-list simple"> +</dl> +<dl class="py property"> +<dt class="sig sig-object py" id="graphnet.utilities.config.configurable.Configurable.config"> +<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><span class="pre">BaseConfig</span></a></em><a class="headerlink" href="#graphnet.utilities.config.configurable.Configurable.config" title="Link to this definition">¶</a></dt> +<dd><p>Return configuration to re-create the instance.</p> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.utilities.config.configurable.Configurable.save_config"> +<span class="sig-name descname"><span class="pre">save_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/configurable.html#Configurable.save_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.configurable.Configurable.save_config" title="Link to this definition">¶</a></dt> +<dd><p>Save Config to <cite>path</cite> as YAML file.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>path</strong> (<em>str</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.utilities.config.configurable.Configurable.from_config"> +<em class="property"><span class="pre">abstract</span><span class="w"> </span><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">from_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">source</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/configurable.html#Configurable.from_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.configurable.Configurable.from_config" title="Link to this definition">¶</a></dt> +<dd><p>Construct instance from <cite>source</cite> configuration.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>source</strong> (<a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><em>BaseConfig</em></a><em> | </em><em>str</em>) – </p> +</dd> +</dl> +</dd></dl> +</dd></dl> </section> @@ -526,7 +625,7 @@ <h1 id="api-graphnet-utilities-config-configurable--page-root">configurable<a cl </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.config.dataset_config.html b/api/graphnet.utilities.config.dataset_config.html index 27d11fb12..d3b202a27 100644 --- a/api/graphnet.utilities.config.dataset_config.html +++ b/api/graphnet.utilities.config.dataset_config.html @@ -371,11 +371,198 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-dataset-config--page-root" class="md-nav__link">dataset_config</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DatasetConfig</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">pulsemaps</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.features" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">features</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.index_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">index_column</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth_table</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">string_selection</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">selection</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_table</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_column</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_default_value</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.seed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">seed</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">graph_definition</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.save_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_dataset_config()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DatasetConfig</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">pulsemaps</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.features" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">features</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.index_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">index_column</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth_table</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">string_selection</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">selection</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_table</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_column</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_default_value</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.seed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">seed</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">graph_definition</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.dataset_config.save_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_dataset_config()</span></code></a> + + + </li></ul> + </li> <li class="md-nav__item"> @@ -465,7 +652,54 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-dataset-config--page-root" class="md-nav__link">dataset_config</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">DatasetConfig</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.path" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">path</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">pulsemaps</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.features" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">features</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.index_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">index_column</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">truth_table</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">node_truth_table</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">string_selection</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.selection" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">selection</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_table</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_column</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">loss_weight_default_value</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.seed" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">seed</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">graph_definition</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.dataset_config.save_dataset_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_dataset_config()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -475,8 +709,206 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="dataset-config"> -<h1 id="api-graphnet-utilities-config-dataset-config--page-root">dataset_config<a class="headerlink" href="#api-graphnet-utilities-config-dataset-config--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.utilities.config.dataset_config"> +<span id="dataset-config"></span><h1 id="api-graphnet-utilities-config-dataset-config--page-root">dataset_config<a class="headerlink" href="#api-graphnet-utilities-config-dataset-config--page-root" title="Link to this heading">¶</a></h1> +<p>Config classes for the <cite>graphnet.data.dataset</cite> module.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.dataset_config.</span></span><span class="sig-name descname"><span class="pre">DatasetConfig</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pulsemaps</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">index_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">node_truth_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">string_selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">selection</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_table</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_column</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss_weight_default_value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">graph_definition</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/dataset_config.html#DatasetConfig"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></p> +<p>Configuration for all <a href="#id1"><span class="problematic" id="id2">`</span></a>Dataset`s.</p> +<p>Construct <cite>DataConfig</cite>.</p> +<p>Can be used for dataset configuration as code, thereby making dataset +construction more transparent and reproducible.</p> +<p class="rubric">Examples</p> +<p>In one session, do:</p> +<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">dataset</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> +<span class="gp">>>> </span><span class="n">dataset</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">()</span> +<span class="go">path: (...)</span> +<span class="go">pulsemaps:</span> +<span class="go"> - (...)</span> +<span class="go">(...)</span> +<span class="gp">>>> </span><span class="n">dataset</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="s2">"dataset.yml"</span><span class="p">)</span> +</pre></div> +</div> +<p>In another session, you can then do: +>>> dataset = Dataset.from_config(“dataset.yml”)</p> +<p># Uniquely for <cite>DatasetConfig</cite>, you can also define and load +# multiple datasets +>>> dataset.config.selection = {</p> +<blockquote> +<div><p>“train”: “event_no % 2 == 0”, +“test”: “event_no % 2 == 1”,</p> +</div></blockquote> +<p>} +>>> dataset.config.dump(“dataset.yml”) +>>> datasets: Dict[str, Dataset] = Dataset.from_config(</p> +<blockquote> +<div><p>“dataset.yml”</p> +</div></blockquote> +<p>) +>>> datasets +{</p> +<blockquote> +<div><p>“train”: Dataset(…), +“test”: Dataset(…),</p> +</div></blockquote> +<p>}</p> +<p># You can also combine multiple selections into a single, named +# dataset +>>> dataset.config.selection = {</p> +<blockquote> +<div><dl class="simple"> +<dt>“train”: [</dt><dd><p>“event_no % 2 == 0 & abs(pid) == 12”, +“event_no % 2 == 0 & abs(pid) == 14”, +“event_no % 2 == 0 & abs(pid) == 16”,</p> +</dd> +</dl> +<p>], +(…)</p> +</div></blockquote> +<p>} +>>> dataset.config.dump(“dataset.yml”) +>>> datasets: Dict[str, EnsembleDataset] = Dataset.from_config(</p> +<blockquote> +<div><p>“dataset.yml”</p> +</div></blockquote> +<p>) +>>> datasets +{</p> +<blockquote> +<div><p>“train”: EnsembleDataset(…), +(…)</p> +</div></blockquote> +<p>}</p> +<p># Finally, you can still reference existing selection files in CSV +# or JSON formats: +>>> dataset.config.selection = {</p> +<blockquote> +<div><p>“train”: “50000 random events ~ train_selection.csv”, +“test”: “test_selection.csv”,</p> +</div></blockquote> +<p>}</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>path</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>pulsemaps</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>features</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>truth</strong> (<em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>node_truth</strong> (<em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>index_column</strong> (<em>str</em>) – </p></li> +<li><p><strong>truth_table</strong> (<em>str</em>) – </p></li> +<li><p><strong>node_truth_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>string_selection</strong> (<em>List</em><em>[</em><em>int</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>selection</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>] </em><em>| </em><em>List</em><em>[</em><em>int</em><em> | </em><em>List</em><em>[</em><em>int</em><em>]</em><em>] </em><em>| </em><em>Dict</em><em>[</em><em>str</em><em>, </em><em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em><em>] </em><em>| </em><em>None</em>) – </p></li> +<li><p><strong>loss_weight_table</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>loss_weight_column</strong> (<em>str</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>loss_weight_default_value</strong> (<em>float</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>seed</strong> (<em>int</em><em> | </em><em>None</em>) – </p></li> +<li><p><strong>graph_definition</strong> (<em>Any</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.path"> +<span class="sig-name descname"><span class="pre">path</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.path" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps"> +<span class="sig-name descname"><span class="pre">pulsemaps</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.features"> +<span class="sig-name descname"><span class="pre">features</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.features" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.truth"> +<span class="sig-name descname"><span class="pre">truth</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.node_truth"> +<span class="sig-name descname"><span class="pre">node_truth</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.index_column"> +<span class="sig-name descname"><span class="pre">index_column</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.index_column" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.truth_table"> +<span class="sig-name descname"><span class="pre">truth_table</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table"> +<span class="sig-name descname"><span class="pre">node_truth_table</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.string_selection"> +<span class="sig-name descname"><span class="pre">string_selection</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.selection"> +<span class="sig-name descname"><span class="pre">selection</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">],</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code><span class="pre">]]],</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]],</span> <code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.selection" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table"> +<span class="sig-name descname"><span class="pre">loss_weight_table</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column"> +<span class="sig-name descname"><span class="pre">loss_weight_column</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value"> +<span class="sig-name descname"><span class="pre">loss_weight_default_value</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.seed"> +<span class="sig-name descname"><span class="pre">seed</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Optional</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.seed" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition"> +<span class="sig-name descname"><span class="pre">graph_definition</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.as_dict"> +<span class="sig-name descname"><span class="pre">as_dict</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/dataset_config.html#DatasetConfig.as_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict" title="Link to this definition">¶</a></dt> +<dd><p>Represent ModelConfig as a dict.</p> +<p>This builds on <cite>BaseModel.dict()</cite> but wraps the output in a single-key +dictionary to make it unambiguous to identify model arguments that are +themselves models.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]</p> +</dd> +</dl> +</dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.model_config"> +<span class="sig-name descname"><span class="pre">model_config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[ConfigDict]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_config" title="Link to this definition">¶</a></dt> +<dd><p>Configuration for the model, should be a dictionary conforming to [<cite>ConfigDict</cite>][pydantic.config.ConfigDict].</p> +</dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.DatasetConfig.model_fields"> +<span class="sig-name descname"><span class="pre">model_fields</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[dict[str,</span> <span class="pre">FieldInfo]]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{'features':</span> <span class="pre">FieldInfo(annotation=List[str],</span> <span class="pre">required=True),</span> <span class="pre">'graph_definition':</span> <span class="pre">FieldInfo(annotation=Any,</span> <span class="pre">required=False),</span> <span class="pre">'index_column':</span> <span class="pre">FieldInfo(annotation=str,</span> <span class="pre">required=False,</span> <span class="pre">default='event_no'),</span> <span class="pre">'loss_weight_column':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'loss_weight_default_value':</span> <span class="pre">FieldInfo(annotation=Union[float,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'loss_weight_table':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'node_truth':</span> <span class="pre">FieldInfo(annotation=Union[List[str],</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'node_truth_table':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'path':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">List[str]],</span> <span class="pre">required=True),</span> <span class="pre">'pulsemaps':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">List[str]],</span> <span class="pre">required=True),</span> <span class="pre">'seed':</span> <span class="pre">FieldInfo(annotation=Union[int,</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'selection':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">List[str],</span> <span class="pre">List[Union[int,</span> <span class="pre">List[int]]],</span> <span class="pre">Dict[str,</span> <span class="pre">Union[str,</span> <span class="pre">List[str]]],</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'string_selection':</span> <span class="pre">FieldInfo(annotation=Union[List[int],</span> <span class="pre">NoneType],</span> <span class="pre">required=False),</span> <span class="pre">'truth':</span> <span class="pre">FieldInfo(annotation=List[str],</span> <span class="pre">required=True),</span> <span class="pre">'truth_table':</span> <span class="pre">FieldInfo(annotation=str,</span> <span class="pre">required=False,</span> <span class="pre">default='truth')}</span></em><a class="headerlink" href="#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields" title="Link to this definition">¶</a></dt> +<dd><p>Metadata about the fields defined on the model, +mapping of field names to [<cite>FieldInfo</cite>][pydantic.fields.FieldInfo].</p> +<p>This replaces <cite>Model.__fields__</cite> from Pydantic V1.</p> +</dd></dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.dataset_config.save_dataset_config"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.dataset_config.</span></span><span class="sig-name descname"><span class="pre">save_dataset_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">init_fn</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/dataset_config.html#save_dataset_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.dataset_config.save_dataset_config" title="Link to this definition">¶</a></dt> +<dd><p>Save the arguments to <cite>__init__</cite> functions as member <cite>DatasetConfig</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>init_fn</strong> (<em>Callable</em>) – </p> +</dd> +</dl> +</dd></dl> </section> @@ -526,7 +958,7 @@ <h1 id="api-graphnet-utilities-config-dataset-config--page-root">dataset_config< </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.config.html b/api/graphnet.utilities.config.html index bba4a6957..59a759d23 100644 --- a/api/graphnet.utilities.config.html +++ b/api/graphnet.utilities.config.html @@ -474,17 +474,44 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="config"> -<h1 id="api-graphnet-utilities-config--page-root">config<a class="headerlink" href="#api-graphnet-utilities-config--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.utilities.config"> +<span id="config"></span><h1 id="api-graphnet-utilities-config--page-root">config<a class="headerlink" href="#api-graphnet-utilities-config--page-root" title="Link to this heading">¶</a></h1> +<p>Modules for configuration files for use across <cite>graphnet</cite>.</p> <p><h2> Submodules </h2></p> <div class="toctree-wrapper compound"> <ul> -<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.base_config.html">base_config</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.configurable.html">configurable</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.dataset_config.html">dataset_config</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.model_config.html">model_config</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.parsing.html">parsing</a></li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.training_config.html">training_config</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.base_config.html">base_config</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig"><code class="docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.get_all_argument_values"><code class="docutils literal notranslate"><span class="pre">get_all_argument_values()</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.configurable.html">configurable</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable"><code class="docutils literal notranslate"><span class="pre">Configurable</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.dataset_config.html">dataset_config</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig"><code class="docutils literal notranslate"><span class="pre">DatasetConfig</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.save_dataset_config"><code class="docutils literal notranslate"><span class="pre">save_dataset_config()</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.model_config.html">model_config</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig"><code class="docutils literal notranslate"><span class="pre">ModelConfig</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.save_model_config"><code class="docutils literal notranslate"><span class="pre">save_model_config()</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.parsing.html">parsing</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.traverse_and_apply"><code class="docutils literal notranslate"><span class="pre">traverse_and_apply()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.list_all_submodules"><code class="docutils literal notranslate"><span class="pre">list_all_submodules()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_all_grapnet_classes"><code class="docutils literal notranslate"><span class="pre">get_all_grapnet_classes()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_module"><code class="docutils literal notranslate"><span class="pre">is_graphnet_module()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_class"><code class="docutils literal notranslate"><span class="pre">is_graphnet_class()</span></code></a></li> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_graphnet_classes"><code class="docutils literal notranslate"><span class="pre">get_graphnet_classes()</span></code></a></li> +</ul> +</li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.config.training_config.html">training_config</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig"><code class="docutils literal notranslate"><span class="pre">TrainingConfig</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -536,7 +563,7 @@ <h1 id="api-graphnet-utilities-config--page-root">config<a class="headerlink" hr </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.config.model_config.html b/api/graphnet.utilities.config.model_config.html index 5c6fa0fbb..52531fc50 100644 --- a/api/graphnet.utilities.config.model_config.html +++ b/api/graphnet.utilities.config.model_config.html @@ -378,11 +378,81 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-model-config--page-root" class="md-nav__link">model_config</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ModelConfig</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.class_name" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">class_name</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.arguments" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">arguments</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.save_model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_model_config()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.model_config.ModelConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ModelConfig</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.model_config.ModelConfig.class_name" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">class_name</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.model_config.ModelConfig.arguments" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">arguments</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.model_config.ModelConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.model_config.ModelConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.model_config.ModelConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + + + </li></ul> + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.model_config.save_model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_model_config()</span></code></a> + + + </li></ul> + </li> <li class="md-nav__item"> @@ -465,7 +535,28 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-model-config--page-root" class="md-nav__link">model_config</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">ModelConfig</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.class_name" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">class_name</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.arguments" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">arguments</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.as_dict" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">as_dict()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.ModelConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + </nav> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.model_config.save_model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">save_model_config()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -475,8 +566,88 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="model-config"> -<h1 id="api-graphnet-utilities-config-model-config--page-root">model_config<a class="headerlink" href="#api-graphnet-utilities-config-model-config--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.utilities.config.model_config"> +<span id="model-config"></span><h1 id="api-graphnet-utilities-config-model-config--page-root">model_config<a class="headerlink" href="#api-graphnet-utilities-config-model-config--page-root" title="Link to this heading">¶</a></h1> +<p>Config classes for the <cite>graphnet.models</cite> module.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.model_config.</span></span><span class="sig-name descname"><span class="pre">ModelConfig</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">class_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">arguments</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/model_config.html#ModelConfig"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></p> +<p>Configuration for all <a href="#id1"><span class="problematic" id="id2">`</span></a>Model`s.</p> +<p>Construct <cite>ModelConfig</cite>.</p> +<p>Can be used for model configuration as code, thereby making model +construction more transparent and reproducible. Note that this does +<em>not</em> save any trainable weights, meaning this is only a configuration +for the model’s hyperparameters. Any model instantiated from a +ModelConfig or file will be randomly initialised, and thus should be +trained.</p> +<p class="rubric">Examples</p> +<p>In one session, do:</p> +<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> +<span class="gp">>>> </span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">()</span> +<span class="go">arguments:</span> +<span class="go"> - (...): (...)</span> +<span class="go">class_name: Model</span> +<span class="gp">>>> </span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="s2">"model.yml"</span><span class="p">)</span> +</pre></div> +</div> +<p>In another session, you can then do: +>>> model = Model.from_config(“model.yml”)</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>class_name</strong> (<em>str</em>) – </p></li> +<li><p><strong>arguments</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Any</em><em>]</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.class_name"> +<span class="sig-name descname"><span class="pre">class_name</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code></em><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.class_name" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.arguments"> +<span class="sig-name descname"><span class="pre">arguments</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.arguments" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py method"> +<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.as_dict"> +<span class="sig-name descname"><span class="pre">as_dict</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/model_config.html#ModelConfig.as_dict"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.as_dict" title="Link to this definition">¶</a></dt> +<dd><p>Represent ModelConfig as a dict.</p> +<p>This builds on <cite>BaseModel.dict()</cite> but wraps the output in a single-key +dictionary to make it unambiguous to identify model arguments that are +themselves models.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code>]]</p> +</dd> +</dl> +</dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.model_config"> +<span class="sig-name descname"><span class="pre">model_config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[ConfigDict]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.model_config" title="Link to this definition">¶</a></dt> +<dd><p>Configuration for the model, should be a dictionary conforming to [<cite>ConfigDict</cite>][pydantic.config.ConfigDict].</p> +</dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.ModelConfig.model_fields"> +<span class="sig-name descname"><span class="pre">model_fields</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[dict[str,</span> <span class="pre">FieldInfo]]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{'arguments':</span> <span class="pre">FieldInfo(annotation=Dict[str,</span> <span class="pre">Any],</span> <span class="pre">required=True),</span> <span class="pre">'class_name':</span> <span class="pre">FieldInfo(annotation=str,</span> <span class="pre">required=True)}</span></em><a class="headerlink" href="#graphnet.utilities.config.model_config.ModelConfig.model_fields" title="Link to this definition">¶</a></dt> +<dd><p>Metadata about the fields defined on the model, +mapping of field names to [<cite>FieldInfo</cite>][pydantic.fields.FieldInfo].</p> +<p>This replaces <cite>Model.__fields__</cite> from Pydantic V1.</p> +</dd></dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.model_config.save_model_config"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.model_config.</span></span><span class="sig-name descname"><span class="pre">save_model_config</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">init_fn</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/model_config.html#save_model_config"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.model_config.save_model_config" title="Link to this definition">¶</a></dt> +<dd><p>Save the arguments to <cite>__init__</cite> functions as a member <cite>ModelConfig</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Callable</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>init_fn</strong> (<em>Callable</em>) – </p> +</dd> +</dl> +</dd></dl> </section> @@ -526,7 +697,7 @@ <h1 id="api-graphnet-utilities-config-model-config--page-root">model_config<a cl </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.config.parsing.html b/api/graphnet.utilities.config.parsing.html index 210d5f809..0e932473f 100644 --- a/api/graphnet.utilities.config.parsing.html +++ b/api/graphnet.utilities.config.parsing.html @@ -385,11 +385,70 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-parsing--page-root" class="md-nav__link">parsing</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.traverse_and_apply" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">traverse_and_apply()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.list_all_submodules" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">list_all_submodules()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.get_all_grapnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_grapnet_classes()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.is_graphnet_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_module()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.is_graphnet_class" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_class()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.get_graphnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_graphnet_classes()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.parsing.traverse_and_apply" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">traverse_and_apply()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.parsing.list_all_submodules" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">list_all_submodules()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.parsing.get_all_grapnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_grapnet_classes()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.parsing.is_graphnet_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_module()</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.parsing.is_graphnet_class" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_class()</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.parsing.get_graphnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_graphnet_classes()</span></code></a> + + + </li></ul> + </li> <li class="md-nav__item"> @@ -465,7 +524,24 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-parsing--page-root" class="md-nav__link">parsing</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.traverse_and_apply" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">traverse_and_apply()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.list_all_submodules" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">list_all_submodules()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.get_all_grapnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_all_grapnet_classes()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.is_graphnet_module" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_module()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.is_graphnet_class" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">is_graphnet_class()</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.parsing.get_graphnet_classes" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">get_graphnet_classes()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -475,8 +551,91 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="parsing"> -<h1 id="api-graphnet-utilities-config-parsing--page-root">parsing<a class="headerlink" href="#api-graphnet-utilities-config-parsing--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.utilities.config.parsing"> +<span id="parsing"></span><h1 id="api-graphnet-utilities-config-parsing--page-root">parsing<a class="headerlink" href="#api-graphnet-utilities-config-parsing--page-root" title="Link to this heading">¶</a></h1> +<p>Utility functions for parsing for using with Config-classes.</p> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.traverse_and_apply"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">traverse_and_apply</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fn_kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#traverse_and_apply"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.traverse_and_apply" title="Link to this definition">¶</a></dt> +<dd><p>Apply <cite>fn</cite> to all elements in <cite>obj</cite>, resulting in same structure.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><ul class="simple"> +<li><p><strong>obj</strong> (<em>Any</em>) – </p></li> +<li><p><strong>fn</strong> (<em>Callable</em>) – </p></li> +<li><p><strong>fn_kwargs</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Any</em><em>] </em><em>| </em><em>None</em>) – </p></li> +</ul> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.list_all_submodules"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">list_all_submodules</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">packages</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#list_all_submodules"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.list_all_submodules" title="Link to this definition">¶</a></dt> +<dd><p>List all submodules in <cite>packages</cite> recursively.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">ModuleType</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>packages</strong> (<em>module</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.get_all_grapnet_classes"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">get_all_grapnet_classes</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">packages</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#get_all_grapnet_classes"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.get_all_grapnet_classes" title="Link to this definition">¶</a></dt> +<dd><p>List all grapnet classes in <cite>packages</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">type</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>packages</strong> (<em>module</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.is_graphnet_module"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">is_graphnet_module</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#is_graphnet_module"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.is_graphnet_module" title="Link to this definition">¶</a></dt> +<dd><p>Return whether <cite>obj</cite> is a module in graphnet.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>obj</strong> (<em>module</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.is_graphnet_class"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">is_graphnet_class</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#is_graphnet_class"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.is_graphnet_class" title="Link to this definition">¶</a></dt> +<dd><p>Return whether <cite>obj</cite> is a class in graphnet.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">bool</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>obj</strong> (<em>type</em>) – </p> +</dd> +</dl> +</dd></dl> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.config.parsing.get_graphnet_classes"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.parsing.</span></span><span class="sig-name descname"><span class="pre">get_graphnet_classes</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">module</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/parsing.html#get_graphnet_classes"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.parsing.get_graphnet_classes" title="Link to this definition">¶</a></dt> +<dd><p>Return a lookup of all graphnet class names in <cite>module</cite>.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code>[<code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">type</span></code>]</p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>module</strong> (<em>module</em>) – </p> +</dd> +</dl> +</dd></dl> </section> @@ -526,7 +685,7 @@ <h1 id="api-graphnet-utilities-config-parsing--page-root">parsing<a class="heade </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.config.training_config.html b/api/graphnet.utilities.config.training_config.html index bdb49dee2..26774e9f2 100644 --- a/api/graphnet.utilities.config.training_config.html +++ b/api/graphnet.utilities.config.training_config.html @@ -392,11 +392,81 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-training-config--page-root" class="md-nav__link">training_config</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TrainingConfig</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.target" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">early_stopping_patience</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dataloader</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.training_config.TrainingConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TrainingConfig</span></code></a> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.training_config.TrainingConfig.target" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">early_stopping_patience</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.training_config.TrainingConfig.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.training_config.TrainingConfig.dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dataloader</span></code></a> + + + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.training_config.TrainingConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.config.training_config.TrainingConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + + + </li></ul> + + </li></ul> + </li></ul> </li> @@ -465,7 +535,28 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-config-training-config--page-root" class="md-nav__link">training_config</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">TrainingConfig</span></code></a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.target" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">target</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">early_stopping_patience</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.fit" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">fit</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.dataloader" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">dataloader</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.model_config" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_config</span></code></a> + </li> + <li class="md-nav__item"><a href="#graphnet.utilities.config.training_config.TrainingConfig.model_fields" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">model_fields</span></code></a> + </li></ul> + </nav> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -475,8 +566,58 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="training-config"> -<h1 id="api-graphnet-utilities-config-training-config--page-root">training_config<a class="headerlink" href="#api-graphnet-utilities-config-training-config--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.utilities.config.training_config"> +<span id="training-config"></span><h1 id="api-graphnet-utilities-config-training-config--page-root">training_config<a class="headerlink" href="#api-graphnet-utilities-config-training-config--page-root" title="Link to this heading">¶</a></h1> +<p>Config classes for the <cite>graphnet.training</cite> module.</p> +<dl class="py class"> +<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig"> +<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">graphnet.utilities.config.training_config.</span></span><span class="sig-name descname"><span class="pre">TrainingConfig</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">early_stopping_patience</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fit</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dataloader</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/config/training_config.html#TrainingConfig"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig" title="Link to this definition">¶</a></dt> +<dd><p>Bases: <a class="reference internal" href="graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig" title="graphnet.utilities.config.base_config.BaseConfig"><code class="xref py py-class docutils literal notranslate"><span class="pre">BaseConfig</span></code></a></p> +<p>Configuration for all trainings.</p> +<p>Create a new model by parsing and validating input data from keyword arguments.</p> +<p>Raises [<cite>ValidationError</cite>][pydantic_core.ValidationError] if the input data cannot be +validated to form a valid model.</p> +<p><cite>__init__</cite> uses <cite>__pydantic_self__</cite> instead of the more common <cite>self</cite> for the first arg to +allow <cite>self</cite> as a field name.</p> +<dl class="field-list simple"> +<dt class="field-odd">Parameters<span class="colon">:</span></dt> +<dd class="field-odd"><ul class="simple"> +<li><p><strong>target</strong> (<em>str</em><em> | </em><em>List</em><em>[</em><em>str</em><em>]</em>) – </p></li> +<li><p><strong>early_stopping_patience</strong> (<em>int</em>) – </p></li> +<li><p><strong>fit</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Any</em><em>]</em>) – </p></li> +<li><p><strong>dataloader</strong> (<em>Dict</em><em>[</em><em>str</em><em>, </em><em>Any</em><em>]</em>) – </p></li> +</ul> +</dd> +</dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.target"> +<span class="sig-name descname"><span class="pre">target</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-data docutils literal notranslate"><span class="pre">Union</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-class docutils literal notranslate"><span class="pre">List</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">]]</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.target" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience"> +<span class="sig-name descname"><span class="pre">early_stopping_patience</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">int</span></code></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.fit"> +<span class="sig-name descname"><span class="pre">fit</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.fit" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.dataloader"> +<span class="sig-name descname"><span class="pre">dataloader</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><code class="xref py py-class docutils literal notranslate"><span class="pre">Dict</span></code><span class="pre">[</span><code class="xref py py-class docutils literal notranslate"><span class="pre">str</span></code><span class="pre">,</span> <code class="xref py py-data docutils literal notranslate"><span class="pre">Any</span></code><span class="pre">]</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.dataloader" title="Link to this definition">¶</a></dt> +<dd></dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.model_config"> +<span class="sig-name descname"><span class="pre">model_config</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[ConfigDict]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{}</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.model_config" title="Link to this definition">¶</a></dt> +<dd><p>Configuration for the model, should be a dictionary conforming to [<cite>ConfigDict</cite>][pydantic.config.ConfigDict].</p> +</dd></dl> +<dl class="py attribute"> +<dt class="sig sig-object py" id="graphnet.utilities.config.training_config.TrainingConfig.model_fields"> +<span class="sig-name descname"><span class="pre">model_fields</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">ClassVar[dict[str,</span> <span class="pre">FieldInfo]]</span></em><em class="property"><span class="w"> </span><span class="p"><span class="pre">=</span></span><span class="w"> </span><span class="pre">{'dataloader':</span> <span class="pre">FieldInfo(annotation=Dict[str,</span> <span class="pre">Any],</span> <span class="pre">required=True),</span> <span class="pre">'early_stopping_patience':</span> <span class="pre">FieldInfo(annotation=int,</span> <span class="pre">required=True),</span> <span class="pre">'fit':</span> <span class="pre">FieldInfo(annotation=Dict[str,</span> <span class="pre">Any],</span> <span class="pre">required=True),</span> <span class="pre">'target':</span> <span class="pre">FieldInfo(annotation=Union[str,</span> <span class="pre">List[str]],</span> <span class="pre">required=True)}</span></em><a class="headerlink" href="#graphnet.utilities.config.training_config.TrainingConfig.model_fields" title="Link to this definition">¶</a></dt> +<dd><p>Metadata about the fields defined on the model, +mapping of field names to [<cite>FieldInfo</cite>][pydantic.fields.FieldInfo].</p> +<p>This replaces <cite>Model.__fields__</cite> from Pydantic V1.</p> +</dd></dl> +</dd></dl> </section> @@ -526,7 +667,7 @@ <h1 id="api-graphnet-utilities-config-training-config--page-root">training_confi </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.decorators.html b/api/graphnet.utilities.decorators.html index 0203d3c40..92ae0d5b7 100644 --- a/api/graphnet.utilities.decorators.html +++ b/api/graphnet.utilities.decorators.html @@ -484,7 +484,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.filesys.html b/api/graphnet.utilities.filesys.html index 8e0c99feb..f13e120c5 100644 --- a/api/graphnet.utilities.filesys.html +++ b/api/graphnet.utilities.filesys.html @@ -604,7 +604,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.html b/api/graphnet.utilities.html index 847defb5c..5723c23e5 100644 --- a/api/graphnet.utilities.html +++ b/api/graphnet.utilities.html @@ -476,7 +476,10 @@ <li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.logging.html#graphnet.utilities.logging.Logger"><code class="docutils literal notranslate"><span class="pre">Logger</span></code></a></li> </ul> </li> -<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.maths.html">maths</a></li> +<li class="toctree-l1"><a class="reference internal" href="graphnet.utilities.maths.html">maths</a><ul> +<li class="toctree-l2"><a class="reference internal" href="graphnet.utilities.maths.html#graphnet.utilities.maths.eps_like"><code class="docutils literal notranslate"><span class="pre">eps_like()</span></code></a></li> +</ul> +</li> </ul> </div> </section> @@ -528,7 +531,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.imports.html b/api/graphnet.utilities.imports.html index 7166dfe02..7ec9aa916 100644 --- a/api/graphnet.utilities.imports.html +++ b/api/graphnet.utilities.imports.html @@ -581,7 +581,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.logging.html b/api/graphnet.utilities.logging.html index a16ad7888..7972b4bad 100644 --- a/api/graphnet.utilities.logging.html +++ b/api/graphnet.utilities.logging.html @@ -831,7 +831,7 @@ </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/graphnet.utilities.maths.html b/api/graphnet.utilities.maths.html index 9fc848fe1..8eb8048fa 100644 --- a/api/graphnet.utilities.maths.html +++ b/api/graphnet.utilities.maths.html @@ -393,11 +393,25 @@ <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-maths--page-root" class="md-nav__link">maths</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.maths.eps_like" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">eps_like()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> + <ul class="md-nav__list"> + <li class="md-nav__item"> + + + <a href="#graphnet.utilities.maths.eps_like" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">eps_like()</span></code></a> + </li></ul> + </li></ul> </li> @@ -422,7 +436,14 @@ <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary"> + <label class="md-nav__title" for="__toc">"Contents"</label> <ul class="md-nav__list" data-md-scrollfix=""> + <li class="md-nav__item"><a href="#api-graphnet-utilities-maths--page-root" class="md-nav__link">maths</a><nav class="md-nav"> + <ul class="md-nav__list"> + <li class="md-nav__item"><a href="#graphnet.utilities.maths.eps_like" class="md-nav__link"><code class="docutils literal notranslate"><span class="pre">eps_like()</span></code></a> + </li></ul> + </nav> + </li> </ul> </nav> </div> @@ -432,8 +453,22 @@ <div class="md-content"> <article class="md-content__inner md-typeset" role="main"> - <section id="maths"> -<h1 id="api-graphnet-utilities-maths--page-root">maths<a class="headerlink" href="#api-graphnet-utilities-maths--page-root" title="Link to this heading">¶</a></h1> + <section id="module-graphnet.utilities.maths"> +<span id="maths"></span><h1 id="api-graphnet-utilities-maths--page-root">maths<a class="headerlink" href="#api-graphnet-utilities-maths--page-root" title="Link to this heading">¶</a></h1> +<p>Collection of assorted “maths-like” functions.</p> +<dl class="py function"> +<dt class="sig sig-object py" id="graphnet.utilities.maths.eps_like"> +<span class="sig-prename descclassname"><span class="pre">graphnet.utilities.maths.</span></span><span class="sig-name descname"><span class="pre">eps_like</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">tensor</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/graphnet/utilities/maths.html#eps_like"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#graphnet.utilities.maths.eps_like" title="Link to this definition">¶</a></dt> +<dd><p>Return <cite>eps</cite> matching <cite>tensor</cite>’s dtype.</p> +<dl class="field-list simple"> +<dt class="field-odd">Return type<span class="colon">:</span></dt> +<dd class="field-odd"><p><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></p> +</dd> +<dt class="field-even">Parameters<span class="colon">:</span></dt> +<dd class="field-even"><p><strong>tensor</strong> (<em>Tensor</em>) – </p> +</dd> +</dl> +</dd></dl> </section> @@ -483,7 +518,7 @@ <h1 id="api-graphnet-utilities-maths--page-root">maths<a class="headerlink" href </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/api/modules.html b/api/modules.html index 6f49ef2ab..0bc19ad98 100644 --- a/api/modules.html +++ b/api/modules.html @@ -362,7 +362,7 @@ <h1 id="api-modules--page-root">src<a class="headerlink" href="#api-modules--pag </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/contribute.html b/contribute.html index 0086e93f9..1a6ccae8d 100644 --- a/contribute.html +++ b/contribute.html @@ -484,7 +484,7 @@ <h2 id="code-quality">Code quality<a class="headerlink" href="#code-quality" tit </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/genindex.html b/genindex.html index 360c7a1ef..7ef7dc30a 100644 --- a/genindex.html +++ b/genindex.html @@ -339,23 +339,44 @@ <h1 id="index">Index</h1> | <a href="#N"><strong>N</strong></a> | <a href="#O"><strong>O</strong></a> | <a href="#P"><strong>P</strong></a> + | <a href="#Q"><strong>Q</strong></a> | <a href="#R"><strong>R</strong></a> | <a href="#S"><strong>S</strong></a> | <a href="#T"><strong>T</strong></a> | <a href="#U"><strong>U</strong></a> + | <a href="#V"><strong>V</strong></a> | <a href="#W"><strong>W</strong></a> + | <a href="#Z"><strong>Z</strong></a> </div> <h2 id="A">A</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.add_label">add_label() (graphnet.data.dataset.dataset.Dataset method)</a> +</li> <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.any_pulsemap_is_non_empty">any_pulsemap_is_non_empty() (graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter method)</a> </li> - </ul></td> - <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.ArgumentParser">ArgumentParser (class in graphnet.utilities.argparse)</a> </li> + <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.arguments">arguments (graphnet.utilities.config.model_config.ModelConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.as_dict">as_dict() (graphnet.utilities.config.base_config.BaseConfig method)</a> + + <ul> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.as_dict">(graphnet.utilities.config.dataset_config.DatasetConfig method)</a> +</li> + <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.as_dict">(graphnet.utilities.config.model_config.ModelConfig method)</a> +</li> + </ul></li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.data.sqlite.sqlite_utilities.html#graphnet.data.sqlite.sqlite_utilities.attach_index">attach_index() (in module graphnet.data.sqlite.sqlite_utilities)</a> +</li> + <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.AttributeCoarsening">AttributeCoarsening (class in graphnet.models.coarsening)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction">AzimuthReconstruction (class in graphnet.models.task.reconstruction)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa">AzimuthReconstructionWithKappa (class in graphnet.models.task.reconstruction)</a> </li> </ul></td> </tr></table> @@ -363,10 +384,20 @@ <h2 id="A">A</h2> <h2 id="B">B</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> - <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.BjoernLow">BjoernLow (class in graphnet.training.weight_fitting)</a> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK.backward">backward() (graphnet.training.loss_functions.LogCMK static method)</a> +</li> + <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig">BaseConfig (class in graphnet.utilities.config.base_config)</a> +</li> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask">BinaryClassificationTask (class in graphnet.models.task.classification)</a> </li> </ul></td> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits">BinaryClassificationTaskLogits (class in graphnet.models.task.classification)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.BinaryCrossEntropyLoss">BinaryCrossEntropyLoss (class in graphnet.training.loss_functions)</a> +</li> + <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.BjoernLow">BjoernLow (class in graphnet.training.weight_fitting)</a> +</li> <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.break_cyclic_recursion">break_cyclic_recursion() (in module graphnet.data.extractors.utilities.types)</a> </li> </ul></td> @@ -376,26 +407,62 @@ <h2 id="C">C</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.cache_output_files">cache_output_files() (in module graphnet.data.dataconverter)</a> +</li> + <li><a href="api/graphnet.models.utils.html#graphnet.models.utils.calculate_distance_matrix">calculate_distance_matrix() (in module graphnet.models.utils)</a> +</li> + <li><a href="api/graphnet.models.utils.html#graphnet.models.utils.calculate_xyzt_homophily">calculate_xyzt_homophily() (in module graphnet.models.utils)</a> </li> <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.cast_object_to_pure_python">cast_object_to_pure_python() (in module graphnet.data.extractors.utilities.types)</a> </li> <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.cast_pulse_series_to_pure_python">cast_pulse_series_to_pure_python() (in module graphnet.data.extractors.utilities.types)</a> </li> - <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.config_updater">config_updater() (in module graphnet.pisa.fitting)</a> + <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.class_name">class_name (graphnet.utilities.config.model_config.ModelConfig attribute)</a> </li> - <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe">construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)</a> + <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening">Coarsening (class in graphnet.models.coarsening)</a> +</li> + <li><a href="api/graphnet.data.dataloader.html#graphnet.data.dataloader.collate_fn">collate_fn() (in module graphnet.data.dataloader)</a> + + <ul> + <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.collate_fn">(in module graphnet.training.utils)</a> +</li> + </ul></li> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.ColumnMissingException">ColumnMissingException</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.compute_loss">compute_loss() (graphnet.models.standard_model.StandardModel method)</a> + + <ul> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.compute_loss">(graphnet.models.task.task.Task method)</a> +</li> + </ul></li> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.concatenate">concatenate() (graphnet.data.dataset.dataset.Dataset class method)</a> </li> </ul></td> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.config">config (graphnet.utilities.config.configurable.Configurable property)</a> +</li> + <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.config_updater">config_updater() (in module graphnet.pisa.fitting)</a> +</li> + <li><a href="api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable">Configurable (class in graphnet.utilities.config.configurable)</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.configure_optimizers">configure_optimizers() (graphnet.models.standard_model.StandardModel method)</a> +</li> + <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe">construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)</a> +</li> <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.Options.contains">contains() (graphnet.utilities.argparse.Options method)</a> </li> <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.ContourFitter">ContourFitter (class in graphnet.pisa.fitting)</a> +</li> + <li><a href="api/graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet">ConvNet (class in graphnet.models.gnn.convnet)</a> </li> <li><a href="api/graphnet.data.sqlite.sqlite_utilities.html#graphnet.data.sqlite.sqlite_utilities.create_table">create_table() (in module graphnet.data.sqlite.sqlite_utilities)</a> </li> <li><a href="api/graphnet.data.sqlite.sqlite_utilities.html#graphnet.data.sqlite.sqlite_utilities.create_table_and_save_to_sql">create_table_and_save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)</a> </li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.critical">critical() (graphnet.utilities.logging.Logger method)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.CrossEntropyLoss">CrossEntropyLoss (class in graphnet.training.loss_functions)</a> +</li> + <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.CustomDOMCoarsening">CustomDOMCoarsening (class in graphnet.models.coarsening)</a> </li> </ul></td> </tr></table> @@ -409,8 +476,14 @@ <h2 id="D">D</h2> </li> <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter">DataConverter (class in graphnet.data.dataconverter)</a> </li> - </ul></td> - <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader">DataLoader (class in graphnet.data.dataloader)</a> +</li> + <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.dataloader">dataloader (graphnet.utilities.config.training_config.TrainingConfig attribute)</a> +</li> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset">Dataset (class in graphnet.data.dataset.dataset)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig">DatasetConfig (class in graphnet.utilities.config.dataset_config)</a> +</li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.debug">debug() (graphnet.utilities.logging.Logger method)</a> </li> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.FEATURES.DEEPCORE">DEEPCORE (graphnet.data.constants.FEATURES attribute)</a> @@ -419,16 +492,130 @@ <h2 id="D">D</h2> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH.DEEPCORE">(graphnet.data.constants.TRUTH attribute)</a> </li> </ul></li> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels">default_prediction_labels (graphnet.models.task.classification.BinaryClassificationTask attribute)</a> + + <ul> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels">(graphnet.models.task.classification.BinaryClassificationTaskLogits attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.AzimuthReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels">(graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels">(graphnet.models.task.reconstruction.DirectionReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.EnergyReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels">(graphnet.models.task.reconstruction.EnergyReconstructionWithPower attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels">(graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.InelasticityReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.PositionReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.TimeReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.VertexReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels">(graphnet.models.task.reconstruction.ZenithReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels">(graphnet.models.task.reconstruction.ZenithReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask.default_prediction_labels">(graphnet.models.task.task.IdentityTask property)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.default_prediction_labels">(graphnet.models.task.task.Task property)</a> +</li> + </ul></li> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask.default_target_labels">default_target_labels (graphnet.models.task.classification.BinaryClassificationTask attribute)</a> + + <ul> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels">(graphnet.models.task.classification.BinaryClassificationTaskLogits attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels">(graphnet.models.task.reconstruction.AzimuthReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels">(graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels">(graphnet.models.task.reconstruction.DirectionReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels">(graphnet.models.task.reconstruction.EnergyReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels">(graphnet.models.task.reconstruction.EnergyReconstructionWithPower attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels">(graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels">(graphnet.models.task.reconstruction.InelasticityReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels">(graphnet.models.task.reconstruction.PositionReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels">(graphnet.models.task.reconstruction.TimeReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels">(graphnet.models.task.reconstruction.VertexReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels">(graphnet.models.task.reconstruction.ZenithReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels">(graphnet.models.task.reconstruction.ZenithReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask.default_target_labels">(graphnet.models.task.task.IdentityTask property)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.default_target_labels">(graphnet.models.task.task.Task property)</a> +</li> + </ul></li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector">Detector (class in graphnet.models.detector.detector)</a> +</li> + <li><a href="api/graphnet.training.labels.html#graphnet.training.labels.Direction">Direction (class in graphnet.training.labels)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa">DirectionReconstructionWithKappa (class in graphnet.models.task.reconstruction)</a> +</li> + <li><a href="api/graphnet.data.dataloader.html#graphnet.data.dataloader.do_shuffle">do_shuffle() (in module graphnet.data.dataloader)</a> +</li> + <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.DOMAndTimeWindowCoarsening">DOMAndTimeWindowCoarsening (class in graphnet.models.coarsening)</a> +</li> + <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.DOMCoarsening">DOMCoarsening (class in graphnet.models.coarsening)</a> +</li> + <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.dump">dump() (graphnet.utilities.config.base_config.BaseConfig method)</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge">DynEdge (class in graphnet.models.gnn.dynedge)</a> +</li> + <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv">DynEdgeConv (class in graphnet.models.components.layers)</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST">DynEdgeJINST (class in graphnet.models.gnn.dynedge_jinst)</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO">DynEdgeTITO (class in graphnet.models.gnn.dynedge_kaggle_tito)</a> +</li> + <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans">DynTrans (class in graphnet.models.components.layers)</a> +</li> </ul></td> </tr></table> <h2 id="E">E</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> - <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.error">error() (graphnet.utilities.logging.Logger method)</a> + <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience">early_stopping_patience (graphnet.utilities.config.training_config.TrainingConfig attribute)</a> +</li> + <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito">EdgeConvTito (class in graphnet.models.components.layers)</a> +</li> + <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition">EdgeDefinition (class in graphnet.models.graphs.edges.edges)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction">EnergyReconstruction (class in graphnet.models.task.reconstruction)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower">EnergyReconstructionWithPower (class in graphnet.models.task.reconstruction)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty">EnergyReconstructionWithUncertainty (class in graphnet.models.task.reconstruction)</a> </li> </ul></td> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.EnsembleDataset">EnsembleDataset (class in graphnet.data.dataset.dataset)</a> +</li> + <li><a href="api/graphnet.utilities.maths.html#graphnet.utilities.maths.eps_like">eps_like() (in module graphnet.utilities.maths)</a> +</li> + <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.error">error() (graphnet.utilities.logging.Logger method)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.EuclideanDistanceLoss">EuclideanDistanceLoss (class in graphnet.training.loss_functions)</a> +</li> + <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EuclideanEdges">EuclideanEdges (class in graphnet.models.graphs.edges.edges)</a> +</li> <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter.execute">execute() (graphnet.data.dataconverter.DataConverter method)</a> </li> </ul></td> @@ -437,7 +624,23 @@ <h2 id="E">E</h2> <h2 id="F">F</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector.feature_map">feature_map() (graphnet.models.detector.detector.Detector method)</a> + + <ul> + <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86.feature_map">(graphnet.models.detector.icecube.IceCube86 method)</a> +</li> + <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore.feature_map">(graphnet.models.detector.icecube.IceCubeDeepCore method)</a> +</li> + <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle.feature_map">(graphnet.models.detector.icecube.IceCubeKaggle method)</a> +</li> + <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade.feature_map">(graphnet.models.detector.icecube.IceCubeUpgrade method)</a> +</li> + <li><a href="api/graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus.feature_map">(graphnet.models.detector.prometheus.Prometheus method)</a> +</li> + </ul></li> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.FEATURES">FEATURES (class in graphnet.data.constants)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.features">features (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> </li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.file_handlers">file_handlers (graphnet.utilities.logging.Logger property)</a> </li> @@ -453,12 +656,16 @@ <h2 id="F">F</h2> </li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.RepeatFilter.filter">filter() (graphnet.utilities.logging.RepeatFilter method)</a> </li> - </ul></td> - <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.find_i3_files">find_i3_files() (in module graphnet.utilities.filesys)</a> </li> - <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.WeightFitter.fit">fit() (graphnet.training.weight_fitting.WeightFitter method)</a> + <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.fit">fit (graphnet.utilities.config.training_config.TrainingConfig attribute)</a> </li> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.fit">fit() (graphnet.models.model.Model method)</a> + + <ul> + <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.WeightFitter.fit">(graphnet.training.weight_fitting.WeightFitter method)</a> +</li> + </ul></li> <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.ContourFitter.fit_1d_contour">fit_1d_contour() (graphnet.pisa.fitting.ContourFitter method)</a> </li> <li><a href="api/graphnet.pisa.fitting.html#graphnet.pisa.fitting.ContourFitter.fit_2d_contour">fit_2d_contour() (graphnet.pisa.fitting.ContourFitter method)</a> @@ -467,9 +674,59 @@ <h2 id="F">F</h2> </li> <li><a href="api/graphnet.data.extractors.utilities.collections.html#graphnet.data.extractors.utilities.collections.flatten_nested_dictionary">flatten_nested_dictionary() (in module graphnet.data.extractors.utilities.collections)</a> </li> + <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening.forward">forward() (graphnet.models.coarsening.Coarsening method)</a> + + <ul> + <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynEdgeConv.forward">(graphnet.models.components.layers.DynEdgeConv method)</a> +</li> + <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.DynTrans.forward">(graphnet.models.components.layers.DynTrans method)</a> +</li> + <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.forward">(graphnet.models.components.layers.EdgeConvTito method)</a> +</li> + <li><a href="api/graphnet.models.detector.detector.html#graphnet.models.detector.detector.Detector.forward">(graphnet.models.detector.detector.Detector method)</a> +</li> + <li><a href="api/graphnet.models.gnn.convnet.html#graphnet.models.gnn.convnet.ConvNet.forward">(graphnet.models.gnn.convnet.ConvNet method)</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge.html#graphnet.models.gnn.dynedge.DynEdge.forward">(graphnet.models.gnn.dynedge.DynEdge method)</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge_jinst.html#graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward">(graphnet.models.gnn.dynedge_jinst.DynEdgeJINST method)</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward">(graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO method)</a> +</li> + <li><a href="api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN.forward">(graphnet.models.gnn.gnn.GNN method)</a> +</li> + <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.EdgeDefinition.forward">(graphnet.models.graphs.edges.edges.EdgeDefinition method)</a> +</li> + <li><a href="api/graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition.forward">(graphnet.models.graphs.graph_definition.GraphDefinition method)</a> +</li> + <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.forward">(graphnet.models.graphs.nodes.nodes.NodeDefinition method)</a> +</li> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.forward">(graphnet.models.model.Model method)</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.forward">(graphnet.models.standard_model.StandardModel method)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.forward">(graphnet.models.task.task.Task method)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK.forward">(graphnet.training.loss_functions.LogCMK static method)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction.forward">(graphnet.training.loss_functions.LossFunction method)</a> +</li> + </ul></li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.data.extractors.utilities.frames.html#graphnet.data.extractors.utilities.frames.frame_is_montecarlo">frame_is_montecarlo() (in module graphnet.data.extractors.utilities.frames)</a> </li> <li><a href="api/graphnet.data.extractors.utilities.frames.html#graphnet.data.extractors.utilities.frames.frame_is_noise">frame_is_noise() (in module graphnet.data.extractors.utilities.frames)</a> +</li> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.from_config">from_config() (graphnet.data.dataset.dataset.Dataset class method)</a> + + <ul> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.from_config">(graphnet.models.model.Model class method)</a> +</li> + <li><a href="api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.from_config">(graphnet.utilities.config.configurable.Configurable class method)</a> +</li> + </ul></li> + <li><a href="api/graphnet.data.dataloader.html#graphnet.data.dataloader.DataLoader.from_dataset_config">from_dataset_config() (graphnet.data.dataloader.DataLoader class method)</a> </li> </ul></td> </tr></table> @@ -478,12 +735,30 @@ <h2 id="G">G</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.FileSet.gcd_file">gcd_file (graphnet.data.dataconverter.FileSet attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.get_all_argument_values">get_all_argument_values() (in module graphnet.utilities.config.base_config)</a> +</li> + <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_all_grapnet_classes">get_all_grapnet_classes() (in module graphnet.utilities.config.parsing)</a> +</li> + <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.get_graphnet_classes">get_graphnet_classes() (in module graphnet.utilities.config.parsing)</a> +</li> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR.get_lr">get_lr() (graphnet.training.callbacks.PiecewiseLinearLR method)</a> </li> <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter.get_map_function">get_map_function() (graphnet.data.dataconverter.DataConverter method)</a> </li> <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.get_member_variables">get_member_variables() (in module graphnet.data.extractors.utilities.types)</a> +</li> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.get_metrics">get_metrics() (graphnet.training.callbacks.ProgressBar method)</a> </li> <li><a href="api/graphnet.data.extractors.utilities.frames.html#graphnet.data.extractors.utilities.frames.get_om_keys_and_pulseseries">get_om_keys_and_pulseseries() (in module graphnet.data.extractors.utilities.frames)</a> +</li> + <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.get_predictions">get_predictions() (in module graphnet.training.utils)</a> +</li> + <li><a href="api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN">GNN (class in graphnet.models.gnn.gnn)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition">graph_definition (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.models.graphs.graph_definition.html#graphnet.models.graphs.graph_definition.GraphDefinition">GraphDefinition (class in graphnet.models.graphs.graph_definition)</a> </li> <li> graphnet @@ -518,6 +793,62 @@ <h2 id="G">G</h2> <ul> <li><a href="api/graphnet.data.dataconverter.html#module-graphnet.data.dataconverter">module</a> +</li> + </ul></li> + <li> + graphnet.data.dataloader + + <ul> + <li><a href="api/graphnet.data.dataloader.html#module-graphnet.data.dataloader">module</a> +</li> + </ul></li> + <li> + graphnet.data.dataset + + <ul> + <li><a href="api/graphnet.data.dataset.html#module-graphnet.data.dataset">module</a> +</li> + </ul></li> + <li> + graphnet.data.dataset.dataset + + <ul> + <li><a href="api/graphnet.data.dataset.dataset.html#module-graphnet.data.dataset.dataset">module</a> +</li> + </ul></li> + <li> + graphnet.data.dataset.parquet + + <ul> + <li><a href="api/graphnet.data.dataset.parquet.html#module-graphnet.data.dataset.parquet">module</a> +</li> + </ul></li> + <li> + graphnet.data.dataset.parquet.parquet_dataset + + <ul> + <li><a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#module-graphnet.data.dataset.parquet.parquet_dataset">module</a> +</li> + </ul></li> + <li> + graphnet.data.dataset.sqlite + + <ul> + <li><a href="api/graphnet.data.dataset.sqlite.html#module-graphnet.data.dataset.sqlite">module</a> +</li> + </ul></li> + <li> + graphnet.data.dataset.sqlite.sqlite_dataset + + <ul> + <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#module-graphnet.data.dataset.sqlite.sqlite_dataset">module</a> +</li> + </ul></li> + <li> + graphnet.data.dataset.sqlite.sqlite_dataset_perturbed + + <ul> + <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed">module</a> </li> </ul></li> <li> @@ -632,8 +963,6 @@ <h2 id="G">G</h2> <li><a href="api/graphnet.data.extractors.utilities.frames.html#module-graphnet.data.extractors.utilities.frames">module</a> </li> </ul></li> - </ul></td> - <td style="width: 33%; vertical-align: top;"><ul> <li> graphnet.data.extractors.utilities.types @@ -653,6 +982,13 @@ <h2 id="G">G</h2> <ul> <li><a href="api/graphnet.data.parquet.parquet_dataconverter.html#module-graphnet.data.parquet.parquet_dataconverter">module</a> +</li> + </ul></li> + <li> + graphnet.data.pipeline + + <ul> + <li><a href="api/graphnet.data.pipeline.html#module-graphnet.data.pipeline">module</a> </li> </ul></li> <li> @@ -712,95 +1048,399 @@ <h2 id="G">G</h2> </li> </ul></li> <li> - graphnet.pisa + graphnet.deployment.i3modules.graphnet_module <ul> - <li><a href="api/graphnet.pisa.html#module-graphnet.pisa">module</a> + <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#module-graphnet.deployment.i3modules.graphnet_module">module</a> </li> </ul></li> <li> - graphnet.pisa.fitting + graphnet.models <ul> - <li><a href="api/graphnet.pisa.fitting.html#module-graphnet.pisa.fitting">module</a> + <li><a href="api/graphnet.models.html#module-graphnet.models">module</a> </li> </ul></li> <li> - graphnet.pisa.plotting + graphnet.models.coarsening <ul> - <li><a href="api/graphnet.pisa.plotting.html#module-graphnet.pisa.plotting">module</a> + <li><a href="api/graphnet.models.coarsening.html#module-graphnet.models.coarsening">module</a> </li> </ul></li> <li> - graphnet.training + graphnet.models.components <ul> - <li><a href="api/graphnet.training.html#module-graphnet.training">module</a> + <li><a href="api/graphnet.models.components.html#module-graphnet.models.components">module</a> </li> </ul></li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> <li> - graphnet.training.weight_fitting + graphnet.models.components.layers <ul> - <li><a href="api/graphnet.training.weight_fitting.html#module-graphnet.training.weight_fitting">module</a> + <li><a href="api/graphnet.models.components.layers.html#module-graphnet.models.components.layers">module</a> </li> </ul></li> <li> - graphnet.utilities + graphnet.models.components.pool <ul> - <li><a href="api/graphnet.utilities.html#module-graphnet.utilities">module</a> + <li><a href="api/graphnet.models.components.pool.html#module-graphnet.models.components.pool">module</a> </li> </ul></li> <li> - graphnet.utilities.argparse + graphnet.models.detector <ul> - <li><a href="api/graphnet.utilities.argparse.html#module-graphnet.utilities.argparse">module</a> + <li><a href="api/graphnet.models.detector.html#module-graphnet.models.detector">module</a> </li> </ul></li> <li> - graphnet.utilities.decorators + graphnet.models.detector.detector <ul> - <li><a href="api/graphnet.utilities.decorators.html#module-graphnet.utilities.decorators">module</a> + <li><a href="api/graphnet.models.detector.detector.html#module-graphnet.models.detector.detector">module</a> </li> </ul></li> <li> - graphnet.utilities.filesys + graphnet.models.detector.icecube <ul> - <li><a href="api/graphnet.utilities.filesys.html#module-graphnet.utilities.filesys">module</a> + <li><a href="api/graphnet.models.detector.icecube.html#module-graphnet.models.detector.icecube">module</a> </li> </ul></li> <li> - graphnet.utilities.imports + graphnet.models.detector.prometheus <ul> - <li><a href="api/graphnet.utilities.imports.html#module-graphnet.utilities.imports">module</a> + <li><a href="api/graphnet.models.detector.prometheus.html#module-graphnet.models.detector.prometheus">module</a> </li> </ul></li> <li> - graphnet.utilities.logging + graphnet.models.gnn <ul> - <li><a href="api/graphnet.utilities.logging.html#module-graphnet.utilities.logging">module</a> + <li><a href="api/graphnet.models.gnn.html#module-graphnet.models.gnn">module</a> </li> </ul></li> - </ul></td> -</tr></table> + <li> + graphnet.models.gnn.convnet -<h2 id="H">H</h2> -<table style="width: 100%" class="indextable genindextable"><tr> - <td style="width: 33%; vertical-align: top;"><ul> - <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.handlers">handlers (graphnet.utilities.logging.Logger property)</a> + <ul> + <li><a href="api/graphnet.models.gnn.convnet.html#module-graphnet.models.gnn.convnet">module</a> </li> - <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.has_extension">has_extension() (in module graphnet.utilities.filesys)</a> + </ul></li> + <li> + graphnet.models.gnn.dynedge + + <ul> + <li><a href="api/graphnet.models.gnn.dynedge.html#module-graphnet.models.gnn.dynedge">module</a> </li> - </ul></td> - <td style="width: 33%; vertical-align: top;"><ul> - <li><a href="api/graphnet.utilities.imports.html#graphnet.utilities.imports.has_icecube_package">has_icecube_package() (in module graphnet.utilities.imports)</a> + </ul></li> + <li> + graphnet.models.gnn.dynedge_jinst + + <ul> + <li><a href="api/graphnet.models.gnn.dynedge_jinst.html#module-graphnet.models.gnn.dynedge_jinst">module</a> +</li> + </ul></li> + <li> + graphnet.models.gnn.dynedge_kaggle_tito + + <ul> + <li><a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#module-graphnet.models.gnn.dynedge_kaggle_tito">module</a> +</li> + </ul></li> + <li> + graphnet.models.gnn.gnn + + <ul> + <li><a href="api/graphnet.models.gnn.gnn.html#module-graphnet.models.gnn.gnn">module</a> +</li> + </ul></li> + <li> + graphnet.models.graphs + + <ul> + <li><a href="api/graphnet.models.graphs.html#module-graphnet.models.graphs">module</a> +</li> + </ul></li> + <li> + graphnet.models.graphs.edges + + <ul> + <li><a href="api/graphnet.models.graphs.edges.html#module-graphnet.models.graphs.edges">module</a> +</li> + </ul></li> + <li> + graphnet.models.graphs.edges.edges + + <ul> + <li><a href="api/graphnet.models.graphs.edges.edges.html#module-graphnet.models.graphs.edges.edges">module</a> +</li> + </ul></li> + <li> + graphnet.models.graphs.graph_definition + + <ul> + <li><a href="api/graphnet.models.graphs.graph_definition.html#module-graphnet.models.graphs.graph_definition">module</a> +</li> + </ul></li> + <li> + graphnet.models.graphs.graphs + + <ul> + <li><a href="api/graphnet.models.graphs.graphs.html#module-graphnet.models.graphs.graphs">module</a> +</li> + </ul></li> + <li> + graphnet.models.graphs.nodes + + <ul> + <li><a href="api/graphnet.models.graphs.nodes.html#module-graphnet.models.graphs.nodes">module</a> +</li> + </ul></li> + <li> + graphnet.models.graphs.nodes.nodes + + <ul> + <li><a href="api/graphnet.models.graphs.nodes.nodes.html#module-graphnet.models.graphs.nodes.nodes">module</a> +</li> + </ul></li> + <li> + graphnet.models.model + + <ul> + <li><a href="api/graphnet.models.model.html#module-graphnet.models.model">module</a> +</li> + </ul></li> + <li> + graphnet.models.standard_model + + <ul> + <li><a href="api/graphnet.models.standard_model.html#module-graphnet.models.standard_model">module</a> +</li> + </ul></li> + <li> + graphnet.models.task + + <ul> + <li><a href="api/graphnet.models.task.html#module-graphnet.models.task">module</a> +</li> + </ul></li> + <li> + graphnet.models.task.classification + + <ul> + <li><a href="api/graphnet.models.task.classification.html#module-graphnet.models.task.classification">module</a> +</li> + </ul></li> + <li> + graphnet.models.task.reconstruction + + <ul> + <li><a href="api/graphnet.models.task.reconstruction.html#module-graphnet.models.task.reconstruction">module</a> +</li> + </ul></li> + <li> + graphnet.models.task.task + + <ul> + <li><a href="api/graphnet.models.task.task.html#module-graphnet.models.task.task">module</a> +</li> + </ul></li> + <li> + graphnet.models.utils + + <ul> + <li><a href="api/graphnet.models.utils.html#module-graphnet.models.utils">module</a> +</li> + </ul></li> + <li> + graphnet.pisa + + <ul> + <li><a href="api/graphnet.pisa.html#module-graphnet.pisa">module</a> +</li> + </ul></li> + <li> + graphnet.pisa.fitting + + <ul> + <li><a href="api/graphnet.pisa.fitting.html#module-graphnet.pisa.fitting">module</a> +</li> + </ul></li> + <li> + graphnet.pisa.plotting + + <ul> + <li><a href="api/graphnet.pisa.plotting.html#module-graphnet.pisa.plotting">module</a> +</li> + </ul></li> + <li> + graphnet.training + + <ul> + <li><a href="api/graphnet.training.html#module-graphnet.training">module</a> +</li> + </ul></li> + <li> + graphnet.training.callbacks + + <ul> + <li><a href="api/graphnet.training.callbacks.html#module-graphnet.training.callbacks">module</a> +</li> + </ul></li> + <li> + graphnet.training.labels + + <ul> + <li><a href="api/graphnet.training.labels.html#module-graphnet.training.labels">module</a> +</li> + </ul></li> + <li> + graphnet.training.loss_functions + + <ul> + <li><a href="api/graphnet.training.loss_functions.html#module-graphnet.training.loss_functions">module</a> +</li> + </ul></li> + <li> + graphnet.training.utils + + <ul> + <li><a href="api/graphnet.training.utils.html#module-graphnet.training.utils">module</a> +</li> + </ul></li> + <li> + graphnet.training.weight_fitting + + <ul> + <li><a href="api/graphnet.training.weight_fitting.html#module-graphnet.training.weight_fitting">module</a> +</li> + </ul></li> + <li> + graphnet.utilities + + <ul> + <li><a href="api/graphnet.utilities.html#module-graphnet.utilities">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.argparse + + <ul> + <li><a href="api/graphnet.utilities.argparse.html#module-graphnet.utilities.argparse">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.config + + <ul> + <li><a href="api/graphnet.utilities.config.html#module-graphnet.utilities.config">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.config.base_config + + <ul> + <li><a href="api/graphnet.utilities.config.base_config.html#module-graphnet.utilities.config.base_config">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.config.configurable + + <ul> + <li><a href="api/graphnet.utilities.config.configurable.html#module-graphnet.utilities.config.configurable">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.config.dataset_config + + <ul> + <li><a href="api/graphnet.utilities.config.dataset_config.html#module-graphnet.utilities.config.dataset_config">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.config.model_config + + <ul> + <li><a href="api/graphnet.utilities.config.model_config.html#module-graphnet.utilities.config.model_config">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.config.parsing + + <ul> + <li><a href="api/graphnet.utilities.config.parsing.html#module-graphnet.utilities.config.parsing">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.config.training_config + + <ul> + <li><a href="api/graphnet.utilities.config.training_config.html#module-graphnet.utilities.config.training_config">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.decorators + + <ul> + <li><a href="api/graphnet.utilities.decorators.html#module-graphnet.utilities.decorators">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.filesys + + <ul> + <li><a href="api/graphnet.utilities.filesys.html#module-graphnet.utilities.filesys">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.imports + + <ul> + <li><a href="api/graphnet.utilities.imports.html#module-graphnet.utilities.imports">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.logging + + <ul> + <li><a href="api/graphnet.utilities.logging.html#module-graphnet.utilities.logging">module</a> +</li> + </ul></li> + <li> + graphnet.utilities.maths + + <ul> + <li><a href="api/graphnet.utilities.maths.html#module-graphnet.utilities.maths">module</a> +</li> + </ul></li> + <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module">GraphNeTI3Module (class in graphnet.deployment.i3modules.graphnet_module)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_by">group_by() (in module graphnet.models.components.pool)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_dom">group_pulses_to_dom() (in module graphnet.models.components.pool)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.group_pulses_to_pmt">group_pulses_to_pmt() (in module graphnet.models.components.pool)</a> +</li> + </ul></td> +</tr></table> + +<h2 id="H">H</h2> +<table style="width: 100%" class="indextable genindextable"><tr> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.handlers">handlers (graphnet.utilities.logging.Logger property)</a> +</li> + <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.has_extension">has_extension() (in module graphnet.utilities.filesys)</a> +</li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.utilities.imports.html#graphnet.utilities.imports.has_icecube_package">has_icecube_package() (in module graphnet.utilities.imports)</a> </li> <li><a href="api/graphnet.utilities.imports.html#graphnet.utilities.imports.has_pisa_package">has_pisa_package() (in module graphnet.utilities.imports)</a> </li> @@ -829,12 +1469,16 @@ <h2 id="I">I</h2> <li><a href="api/graphnet.data.extractors.i3hybridrecoextractor.html#graphnet.data.extractors.i3hybridrecoextractor.I3GalacticPlaneHybridRecoExtractor">I3GalacticPlaneHybridRecoExtractor (class in graphnet.data.extractors.i3hybridrecoextractor)</a> </li> <li><a href="api/graphnet.data.extractors.i3genericextractor.html#graphnet.data.extractors.i3genericextractor.I3GenericExtractor">I3GenericExtractor (class in graphnet.data.extractors.i3genericextractor)</a> +</li> + <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3InferenceModule">I3InferenceModule (class in graphnet.deployment.i3modules.graphnet_module)</a> </li> <li><a href="api/graphnet.data.extractors.i3ntmuonlabelsextractor.html#graphnet.data.extractors.i3ntmuonlabelsextractor.I3NTMuonLabelExtractor">I3NTMuonLabelExtractor (class in graphnet.data.extractors.i3ntmuonlabelsextractor)</a> </li> <li><a href="api/graphnet.data.extractors.i3particleextractor.html#graphnet.data.extractors.i3particleextractor.I3ParticleExtractor">I3ParticleExtractor (class in graphnet.data.extractors.i3particleextractor)</a> </li> <li><a href="api/graphnet.data.extractors.i3pisaextractor.html#graphnet.data.extractors.i3pisaextractor.I3PISAExtractor">I3PISAExtractor (class in graphnet.data.extractors.i3pisaextractor)</a> +</li> + <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule">I3PulseCleanerModule (class in graphnet.deployment.i3modules.graphnet_module)</a> </li> <li><a href="api/graphnet.data.extractors.i3featureextractor.html#graphnet.data.extractors.i3featureextractor.I3PulseNoiseTruthFlagIceCubeUpgrade">I3PulseNoiseTruthFlagIceCubeUpgrade (class in graphnet.data.extractors.i3featureextractor)</a> </li> @@ -842,29 +1486,63 @@ <h2 id="I">I</h2> </li> <li><a href="api/graphnet.data.extractors.i3retroextractor.html#graphnet.data.extractors.i3retroextractor.I3RetroExtractor">I3RetroExtractor (class in graphnet.data.extractors.i3retroextractor)</a> </li> - </ul></td> - <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.data.extractors.i3splinempeextractor.html#graphnet.data.extractors.i3splinempeextractor.I3SplineMPEICExtractor">I3SplineMPEICExtractor (class in graphnet.data.extractors.i3splinempeextractor)</a> </li> <li><a href="api/graphnet.data.extractors.i3truthextractor.html#graphnet.data.extractors.i3truthextractor.I3TruthExtractor">I3TruthExtractor (class in graphnet.data.extractors.i3truthextractor)</a> </li> <li><a href="api/graphnet.data.extractors.i3tumextractor.html#graphnet.data.extractors.i3tumextractor.I3TUMExtractor">I3TUMExtractor (class in graphnet.data.extractors.i3tumextractor)</a> +</li> + <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCube86">IceCube86 (class in graphnet.models.detector.icecube)</a> </li> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.FEATURES.ICECUBE86">ICECUBE86 (graphnet.data.constants.FEATURES attribute)</a> <ul> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH.ICECUBE86">(graphnet.data.constants.TRUTH attribute)</a> +</li> + </ul></li> + <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeDeepCore">IceCubeDeepCore (class in graphnet.models.detector.icecube)</a> +</li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeKaggle">IceCubeKaggle (class in graphnet.models.detector.icecube)</a> +</li> + <li><a href="api/graphnet.models.detector.icecube.html#graphnet.models.detector.icecube.IceCubeUpgrade">IceCubeUpgrade (class in graphnet.models.detector.icecube)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask">IdentityTask (class in graphnet.models.task.task)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.index_column">index_column (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction">InelasticityReconstruction (class in graphnet.models.task.reconstruction)</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.inference">inference() (graphnet.models.standard_model.StandardModel method)</a> + + <ul> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.inference">(graphnet.models.task.task.Task method)</a> </li> </ul></li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.info">info() (graphnet.utilities.logging.Logger method)</a> </li> <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.init_global_index">init_global_index() (in module graphnet.data.dataconverter)</a> +</li> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_predict_tqdm">init_predict_tqdm() (graphnet.training.callbacks.ProgressBar method)</a> +</li> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_test_tqdm">init_test_tqdm() (graphnet.training.callbacks.ProgressBar method)</a> +</li> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_train_tqdm">init_train_tqdm() (graphnet.training.callbacks.ProgressBar method)</a> +</li> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.init_validation_tqdm">init_validation_tqdm() (graphnet.training.callbacks.ProgressBar method)</a> +</li> + <li><a href="api/graphnet.data.pipeline.html#graphnet.data.pipeline.InSQLitePipeline">InSQLitePipeline (class in graphnet.data.pipeline)</a> </li> <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.is_boost_class">is_boost_class() (in module graphnet.data.extractors.utilities.types)</a> </li> <li><a href="api/graphnet.data.extractors.utilities.types.html#graphnet.data.extractors.utilities.types.is_boost_enum">is_boost_enum() (in module graphnet.data.extractors.utilities.types)</a> </li> <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.is_gcd_file">is_gcd_file() (in module graphnet.utilities.filesys)</a> +</li> + <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_class">is_graphnet_class() (in module graphnet.utilities.config.parsing)</a> +</li> + <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.is_graphnet_module">is_graphnet_module() (in module graphnet.utilities.config.parsing)</a> </li> <li><a href="api/graphnet.utilities.filesys.html#graphnet.utilities.filesys.is_i3_file">is_i3_file() (in module graphnet.utilities.filesys)</a> </li> @@ -890,13 +1568,57 @@ <h2 id="K">K</h2> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH.KAGGLE">(graphnet.data.constants.TRUTH attribute)</a> </li> </ul></li> + <li><a href="api/graphnet.training.labels.html#graphnet.training.labels.Label.key">key (graphnet.training.labels.Label property)</a> +</li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.utils.html#graphnet.models.utils.knn_graph_batch">knn_graph_batch() (in module graphnet.models.utils)</a> +</li> + <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.KNNEdges">KNNEdges (class in graphnet.models.graphs.edges.edges)</a> +</li> + <li><a href="api/graphnet.models.graphs.graphs.html#graphnet.models.graphs.graphs.KNNGraph">KNNGraph (class in graphnet.models.graphs.graphs)</a> +</li> </ul></td> </tr></table> <h2 id="L">L</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.training.labels.html#graphnet.training.labels.Label">Label (class in graphnet.training.labels)</a> +</li> + <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.list_all_submodules">list_all_submodules() (in module graphnet.utilities.config.parsing)</a> +</li> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.load">load() (graphnet.models.model.Model class method)</a> + + <ul> + <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.load">(graphnet.utilities.config.base_config.BaseConfig class method)</a> +</li> + </ul></li> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.load_module">load_module() (in module graphnet.data.dataset.dataset)</a> +</li> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.load_state_dict">load_state_dict() (graphnet.models.model.Model method)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk">log_cmk() (graphnet.training.loss_functions.VonMisesFisherLoss class method)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx">log_cmk_approx() (graphnet.training.loss_functions.VonMisesFisherLoss class method)</a> +</li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact">log_cmk_exact() (graphnet.training.loss_functions.VonMisesFisherLoss class method)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCMK">LogCMK (class in graphnet.training.loss_functions)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LogCoshLoss">LogCoshLoss (class in graphnet.training.loss_functions)</a> +</li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger">Logger (class in graphnet.utilities.logging)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column">loss_weight_column (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value">loss_weight_default_value (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table">loss_weight_table (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.LossFunction">LossFunction (class in graphnet.training.loss_functions)</a> </li> </ul></td> </tr></table> @@ -904,6 +1626,10 @@ <h2 id="L">L</h2> <h2 id="M">M</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.make_dataloader">make_dataloader() (in module graphnet.training.utils)</a> +</li> + <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.make_train_validation_dataloader">make_train_validation_dataloader() (in module graphnet.training.utils)</a> +</li> <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter.merge_files">merge_files() (graphnet.data.dataconverter.DataConverter method)</a> <ul> @@ -912,6 +1638,36 @@ <h2 id="M">M</h2> <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.merge_files">(graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter method)</a> </li> </ul></li> + <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.message">message() (graphnet.models.components.layers.EdgeConvTito method)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool">min_pool() (in module graphnet.models.components.pool)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.min_pool_x">min_pool_x() (in module graphnet.models.components.pool)</a> +</li> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model">Model (class in graphnet.models.model)</a> +</li> + <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.model_config">model_config (graphnet.utilities.config.base_config.BaseConfig attribute)</a> + + <ul> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.model_config">(graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.model_config">(graphnet.utilities.config.model_config.ModelConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.model_config">(graphnet.utilities.config.training_config.TrainingConfig attribute)</a> +</li> + </ul></li> + <li><a href="api/graphnet.utilities.config.base_config.html#graphnet.utilities.config.base_config.BaseConfig.model_fields">model_fields (graphnet.utilities.config.base_config.BaseConfig attribute)</a> + + <ul> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.model_fields">(graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig.model_fields">(graphnet.utilities.config.model_config.ModelConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.model_fields">(graphnet.utilities.config.training_config.TrainingConfig attribute)</a> +</li> + </ul></li> + <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.ModelConfig">ModelConfig (class in graphnet.utilities.config.model_config)</a> +</li> <li> module @@ -925,6 +1681,22 @@ <h2 id="M">M</h2> <li><a href="api/graphnet.data.constants.html#module-graphnet.data.constants">graphnet.data.constants</a> </li> <li><a href="api/graphnet.data.dataconverter.html#module-graphnet.data.dataconverter">graphnet.data.dataconverter</a> +</li> + <li><a href="api/graphnet.data.dataloader.html#module-graphnet.data.dataloader">graphnet.data.dataloader</a> +</li> + <li><a href="api/graphnet.data.dataset.html#module-graphnet.data.dataset">graphnet.data.dataset</a> +</li> + <li><a href="api/graphnet.data.dataset.dataset.html#module-graphnet.data.dataset.dataset">graphnet.data.dataset.dataset</a> +</li> + <li><a href="api/graphnet.data.dataset.parquet.html#module-graphnet.data.dataset.parquet">graphnet.data.dataset.parquet</a> +</li> + <li><a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#module-graphnet.data.dataset.parquet.parquet_dataset">graphnet.data.dataset.parquet.parquet_dataset</a> +</li> + <li><a href="api/graphnet.data.dataset.sqlite.html#module-graphnet.data.dataset.sqlite">graphnet.data.dataset.sqlite</a> +</li> + <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#module-graphnet.data.dataset.sqlite.sqlite_dataset">graphnet.data.dataset.sqlite.sqlite_dataset</a> +</li> + <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed">graphnet.data.dataset.sqlite.sqlite_dataset_perturbed</a> </li> <li><a href="api/graphnet.data.extractors.html#module-graphnet.data.extractors">graphnet.data.extractors</a> </li> @@ -963,6 +1735,8 @@ <h2 id="M">M</h2> <li><a href="api/graphnet.data.parquet.html#module-graphnet.data.parquet">graphnet.data.parquet</a> </li> <li><a href="api/graphnet.data.parquet.parquet_dataconverter.html#module-graphnet.data.parquet.parquet_dataconverter">graphnet.data.parquet.parquet_dataconverter</a> +</li> + <li><a href="api/graphnet.data.pipeline.html#module-graphnet.data.pipeline">graphnet.data.pipeline</a> </li> <li><a href="api/graphnet.data.sqlite.html#module-graphnet.data.sqlite">graphnet.data.sqlite</a> </li> @@ -979,6 +1753,66 @@ <h2 id="M">M</h2> <li><a href="api/graphnet.data.utilities.string_selection_resolver.html#module-graphnet.data.utilities.string_selection_resolver">graphnet.data.utilities.string_selection_resolver</a> </li> <li><a href="api/graphnet.deployment.html#module-graphnet.deployment">graphnet.deployment</a> +</li> + <li><a href="api/graphnet.deployment.i3modules.graphnet_module.html#module-graphnet.deployment.i3modules.graphnet_module">graphnet.deployment.i3modules.graphnet_module</a> +</li> + <li><a href="api/graphnet.models.html#module-graphnet.models">graphnet.models</a> +</li> + <li><a href="api/graphnet.models.coarsening.html#module-graphnet.models.coarsening">graphnet.models.coarsening</a> +</li> + <li><a href="api/graphnet.models.components.html#module-graphnet.models.components">graphnet.models.components</a> +</li> + <li><a href="api/graphnet.models.components.layers.html#module-graphnet.models.components.layers">graphnet.models.components.layers</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#module-graphnet.models.components.pool">graphnet.models.components.pool</a> +</li> + <li><a href="api/graphnet.models.detector.html#module-graphnet.models.detector">graphnet.models.detector</a> +</li> + <li><a href="api/graphnet.models.detector.detector.html#module-graphnet.models.detector.detector">graphnet.models.detector.detector</a> +</li> + <li><a href="api/graphnet.models.detector.icecube.html#module-graphnet.models.detector.icecube">graphnet.models.detector.icecube</a> +</li> + <li><a href="api/graphnet.models.detector.prometheus.html#module-graphnet.models.detector.prometheus">graphnet.models.detector.prometheus</a> +</li> + <li><a href="api/graphnet.models.gnn.html#module-graphnet.models.gnn">graphnet.models.gnn</a> +</li> + <li><a href="api/graphnet.models.gnn.convnet.html#module-graphnet.models.gnn.convnet">graphnet.models.gnn.convnet</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge.html#module-graphnet.models.gnn.dynedge">graphnet.models.gnn.dynedge</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge_jinst.html#module-graphnet.models.gnn.dynedge_jinst">graphnet.models.gnn.dynedge_jinst</a> +</li> + <li><a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#module-graphnet.models.gnn.dynedge_kaggle_tito">graphnet.models.gnn.dynedge_kaggle_tito</a> +</li> + <li><a href="api/graphnet.models.gnn.gnn.html#module-graphnet.models.gnn.gnn">graphnet.models.gnn.gnn</a> +</li> + <li><a href="api/graphnet.models.graphs.html#module-graphnet.models.graphs">graphnet.models.graphs</a> +</li> + <li><a href="api/graphnet.models.graphs.edges.html#module-graphnet.models.graphs.edges">graphnet.models.graphs.edges</a> +</li> + <li><a href="api/graphnet.models.graphs.edges.edges.html#module-graphnet.models.graphs.edges.edges">graphnet.models.graphs.edges.edges</a> +</li> + <li><a href="api/graphnet.models.graphs.graph_definition.html#module-graphnet.models.graphs.graph_definition">graphnet.models.graphs.graph_definition</a> +</li> + <li><a href="api/graphnet.models.graphs.graphs.html#module-graphnet.models.graphs.graphs">graphnet.models.graphs.graphs</a> +</li> + <li><a href="api/graphnet.models.graphs.nodes.html#module-graphnet.models.graphs.nodes">graphnet.models.graphs.nodes</a> +</li> + <li><a href="api/graphnet.models.graphs.nodes.nodes.html#module-graphnet.models.graphs.nodes.nodes">graphnet.models.graphs.nodes.nodes</a> +</li> + <li><a href="api/graphnet.models.model.html#module-graphnet.models.model">graphnet.models.model</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#module-graphnet.models.standard_model">graphnet.models.standard_model</a> +</li> + <li><a href="api/graphnet.models.task.html#module-graphnet.models.task">graphnet.models.task</a> +</li> + <li><a href="api/graphnet.models.task.classification.html#module-graphnet.models.task.classification">graphnet.models.task.classification</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#module-graphnet.models.task.reconstruction">graphnet.models.task.reconstruction</a> +</li> + <li><a href="api/graphnet.models.task.task.html#module-graphnet.models.task.task">graphnet.models.task.task</a> +</li> + <li><a href="api/graphnet.models.utils.html#module-graphnet.models.utils">graphnet.models.utils</a> </li> <li><a href="api/graphnet.pisa.html#module-graphnet.pisa">graphnet.pisa</a> </li> @@ -987,12 +1821,34 @@ <h2 id="M">M</h2> <li><a href="api/graphnet.pisa.plotting.html#module-graphnet.pisa.plotting">graphnet.pisa.plotting</a> </li> <li><a href="api/graphnet.training.html#module-graphnet.training">graphnet.training</a> +</li> + <li><a href="api/graphnet.training.callbacks.html#module-graphnet.training.callbacks">graphnet.training.callbacks</a> +</li> + <li><a href="api/graphnet.training.labels.html#module-graphnet.training.labels">graphnet.training.labels</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#module-graphnet.training.loss_functions">graphnet.training.loss_functions</a> +</li> + <li><a href="api/graphnet.training.utils.html#module-graphnet.training.utils">graphnet.training.utils</a> </li> <li><a href="api/graphnet.training.weight_fitting.html#module-graphnet.training.weight_fitting">graphnet.training.weight_fitting</a> </li> <li><a href="api/graphnet.utilities.html#module-graphnet.utilities">graphnet.utilities</a> </li> <li><a href="api/graphnet.utilities.argparse.html#module-graphnet.utilities.argparse">graphnet.utilities.argparse</a> +</li> + <li><a href="api/graphnet.utilities.config.html#module-graphnet.utilities.config">graphnet.utilities.config</a> +</li> + <li><a href="api/graphnet.utilities.config.base_config.html#module-graphnet.utilities.config.base_config">graphnet.utilities.config.base_config</a> +</li> + <li><a href="api/graphnet.utilities.config.configurable.html#module-graphnet.utilities.config.configurable">graphnet.utilities.config.configurable</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#module-graphnet.utilities.config.dataset_config">graphnet.utilities.config.dataset_config</a> +</li> + <li><a href="api/graphnet.utilities.config.model_config.html#module-graphnet.utilities.config.model_config">graphnet.utilities.config.model_config</a> +</li> + <li><a href="api/graphnet.utilities.config.parsing.html#module-graphnet.utilities.config.parsing">graphnet.utilities.config.parsing</a> +</li> + <li><a href="api/graphnet.utilities.config.training_config.html#module-graphnet.utilities.config.training_config">graphnet.utilities.config.training_config</a> </li> <li><a href="api/graphnet.utilities.decorators.html#module-graphnet.utilities.decorators">graphnet.utilities.decorators</a> </li> @@ -1001,9 +1857,17 @@ <h2 id="M">M</h2> <li><a href="api/graphnet.utilities.imports.html#module-graphnet.utilities.imports">graphnet.utilities.imports</a> </li> <li><a href="api/graphnet.utilities.logging.html#module-graphnet.utilities.logging">graphnet.utilities.logging</a> +</li> + <li><a href="api/graphnet.utilities.maths.html#module-graphnet.utilities.maths">graphnet.utilities.maths</a> </li> </ul></li> </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.MSELoss">MSELoss (class in graphnet.training.loss_functions)</a> +</li> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.MulticlassClassificationTask">MulticlassClassificationTask (class in graphnet.models.task.classification)</a> +</li> + </ul></td> </tr></table> <h2 id="N">N</h2> @@ -1011,9 +1875,59 @@ <h2 id="N">N</h2> <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.data.extractors.i3extractor.html#graphnet.data.extractors.i3extractor.I3Extractor.name">name (graphnet.data.extractors.i3extractor.I3Extractor property)</a> </li> + <li><a href="api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN.nb_inputs">nb_inputs (graphnet.models.gnn.gnn.GNN property)</a> + + <ul> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTask.nb_inputs">(graphnet.models.task.classification.BinaryClassificationTask attribute)</a> +</li> + <li><a href="api/graphnet.models.task.classification.html#graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs">(graphnet.models.task.classification.BinaryClassificationTaskLogits attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs">(graphnet.models.task.reconstruction.AzimuthReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs">(graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs">(graphnet.models.task.reconstruction.DirectionReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs">(graphnet.models.task.reconstruction.EnergyReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs">(graphnet.models.task.reconstruction.EnergyReconstructionWithPower attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs">(graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs">(graphnet.models.task.reconstruction.InelasticityReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs">(graphnet.models.task.reconstruction.PositionReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs">(graphnet.models.task.reconstruction.TimeReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs">(graphnet.models.task.reconstruction.VertexReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs">(graphnet.models.task.reconstruction.ZenithReconstruction attribute)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs">(graphnet.models.task.reconstruction.ZenithReconstructionWithKappa attribute)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.IdentityTask.nb_inputs">(graphnet.models.task.task.IdentityTask property)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.nb_inputs">(graphnet.models.task.task.Task property)</a> +</li> + </ul></li> </ul></td> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.gnn.gnn.html#graphnet.models.gnn.gnn.GNN.nb_outputs">nb_outputs (graphnet.models.gnn.gnn.GNN property)</a> + + <ul> + <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs">(graphnet.models.graphs.nodes.nodes.NodeDefinition property)</a> +</li> + </ul></li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.RepeatFilter.nb_repeats_allowed">nb_repeats_allowed (graphnet.utilities.logging.RepeatFilter attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth">node_truth (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table">node_truth_table (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition">NodeDefinition (class in graphnet.models.graphs.nodes.nodes)</a> +</li> + <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodesAsPulses">NodesAsPulses (class in graphnet.models.graphs.nodes.nodes)</a> </li> </ul></td> </tr></table> @@ -1021,6 +1935,12 @@ <h2 id="N">N</h2> <h2 id="O">O</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.on_train_epoch_end">on_train_epoch_end() (graphnet.training.callbacks.ProgressBar method)</a> +</li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar.on_train_epoch_start">on_train_epoch_start() (graphnet.training.callbacks.ProgressBar method)</a> +</li> <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.Options">Options (class in graphnet.utilities.argparse)</a> </li> </ul></td> @@ -1032,21 +1952,69 @@ <h2 id="P">P</h2> <li><a href="api/graphnet.data.utilities.random.html#graphnet.data.utilities.random.pairwise_shuffle">pairwise_shuffle() (in module graphnet.data.utilities.random)</a> </li> <li><a href="api/graphnet.data.parquet.parquet_dataconverter.html#graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter">ParquetDataConverter (class in graphnet.data.parquet.parquet_dataconverter)</a> +</li> + <li><a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset">ParquetDataset (class in graphnet.data.dataset.parquet.parquet_dataset)</a> </li> <li><a href="api/graphnet.data.utilities.parquet_to_sqlite.html#graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter">ParquetToSQLiteConverter (class in graphnet.data.utilities.parquet_to_sqlite)</a> +</li> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.parse_graph_definition">parse_graph_definition() (in module graphnet.data.dataset.dataset)</a> +</li> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.path">path (graphnet.data.dataset.dataset.Dataset property)</a> + + <ul> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.path">(graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + </ul></li> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.PiecewiseLinearLR">PiecewiseLinearLR (class in graphnet.training.callbacks)</a> </li> <li><a href="api/graphnet.pisa.plotting.html#graphnet.pisa.plotting.plot_1D_contour">plot_1D_contour() (in module graphnet.pisa.plotting)</a> </li> - </ul></td> - <td style="width: 33%; vertical-align: top;"><ul> <li><a href="api/graphnet.pisa.plotting.html#graphnet.pisa.plotting.plot_2D_contour">plot_2D_contour() (in module graphnet.pisa.plotting)</a> </li> <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.Options.pop_default">pop_default() (graphnet.utilities.argparse.Options method)</a> +</li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.PositionReconstruction">PositionReconstruction (class in graphnet.models.task.reconstruction)</a> +</li> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.predict">predict() (graphnet.models.model.Model method)</a> + + <ul> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.predict">(graphnet.models.standard_model.StandardModel method)</a> +</li> + </ul></li> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.predict_as_dataframe">predict_as_dataframe() (graphnet.models.model.Model method)</a> + + <ul> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.predict_as_dataframe">(graphnet.models.standard_model.StandardModel method)</a> +</li> + </ul></li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.prediction_labels">prediction_labels (graphnet.models.standard_model.StandardModel property)</a> +</li> + <li><a href="api/graphnet.training.callbacks.html#graphnet.training.callbacks.ProgressBar">ProgressBar (class in graphnet.training.callbacks)</a> +</li> + <li><a href="api/graphnet.models.detector.prometheus.html#graphnet.models.detector.prometheus.Prometheus">Prometheus (class in graphnet.models.detector.prometheus)</a> </li> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.FEATURES.PROMETHEUS">PROMETHEUS (graphnet.data.constants.FEATURES attribute)</a> <ul> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH.PROMETHEUS">(graphnet.data.constants.TRUTH attribute)</a> +</li> + </ul></li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps">pulsemaps (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + </ul></td> +</tr></table> + +<h2 id="Q">Q</h2> +<table style="width: 100%" class="indextable genindextable"><tr> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.query_table">query_table() (graphnet.data.dataset.dataset.Dataset method)</a> + + <ul> + <li><a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table">(graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset method)</a> +</li> + <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table">(graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset method)</a> </li> </ul></li> </ul></td> @@ -1055,7 +2023,11 @@ <h2 id="P">P</h2> <h2 id="R">R</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.graphs.edges.edges.html#graphnet.models.graphs.edges.edges.RadialEdges">RadialEdges (class in graphnet.models.graphs.edges.edges)</a> +</li> <li><a href="api/graphnet.pisa.plotting.html#graphnet.pisa.plotting.read_entry">read_entry() (in module graphnet.pisa.plotting)</a> +</li> + <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.Coarsening.reduce_options">reduce_options (graphnet.models.coarsening.Coarsening attribute)</a> </li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.RepeatFilter">RepeatFilter (class in graphnet.utilities.logging)</a> </li> @@ -1063,7 +2035,11 @@ <h2 id="R">R</h2> </li> </ul></td> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.components.layers.html#graphnet.models.components.layers.EdgeConvTito.reset_parameters">reset_parameters() (graphnet.models.components.layers.EdgeConvTito method)</a> +</li> <li><a href="api/graphnet.data.utilities.string_selection_resolver.html#graphnet.data.utilities.string_selection_resolver.StringSelectionResolver.resolve">resolve() (graphnet.data.utilities.string_selection_resolver.StringSelectionResolver method)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.RMSELoss">RMSELoss (class in graphnet.training.loss_functions)</a> </li> <li><a href="api/graphnet.data.utilities.parquet_to_sqlite.html#graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter.run">run() (graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter method)</a> </li> @@ -1075,6 +2051,10 @@ <h2 id="R">R</h2> <h2 id="S">S</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.save">save() (graphnet.models.model.Model method)</a> +</li> + <li><a href="api/graphnet.utilities.config.configurable.html#graphnet.utilities.config.configurable.Configurable.save_config">save_config() (graphnet.utilities.config.configurable.Configurable method)</a> +</li> <li><a href="api/graphnet.data.dataconverter.html#graphnet.data.dataconverter.DataConverter.save_data">save_data() (graphnet.data.dataconverter.DataConverter method)</a> <ul> @@ -1083,7 +2063,19 @@ <h2 id="S">S</h2> <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.save_data">(graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter method)</a> </li> </ul></li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.save_dataset_config">save_dataset_config() (in module graphnet.utilities.config.dataset_config)</a> +</li> + <li><a href="api/graphnet.utilities.config.model_config.html#graphnet.utilities.config.model_config.save_model_config">save_model_config() (in module graphnet.utilities.config.model_config)</a> +</li> + <li><a href="api/graphnet.training.utils.html#graphnet.training.utils.save_results">save_results() (in module graphnet.training.utils)</a> +</li> + <li><a href="api/graphnet.models.model.html#graphnet.models.model.Model.save_state_dict">save_state_dict() (graphnet.models.model.Model method)</a> +</li> <li><a href="api/graphnet.data.sqlite.sqlite_utilities.html#graphnet.data.sqlite.sqlite_utilities.save_to_sql">save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.seed">seed (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.selection">selection (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> </li> <li><a href="api/graphnet.data.extractors.utilities.collections.html#graphnet.data.extractors.utilities.collections.serialise">serialise() (in module graphnet.data.extractors.utilities.collections)</a> </li> @@ -1095,15 +2087,37 @@ <h2 id="S">S</h2> </ul></li> </ul></td> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.graphs.nodes.nodes.html#graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs">set_number_of_inputs() (graphnet.models.graphs.nodes.nodes.NodeDefinition method)</a> +</li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.setLevel">setLevel() (graphnet.utilities.logging.Logger method)</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.shared_step">shared_step() (graphnet.models.standard_model.StandardModel method)</a> </li> <li><a href="api/graphnet.data.sqlite.sqlite_dataconverter.html#graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter">SQLiteDataConverter (class in graphnet.data.sqlite.sqlite_dataconverter)</a> +</li> + <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset">SQLiteDataset (class in graphnet.data.dataset.sqlite.sqlite_dataset)</a> +</li> + <li><a href="api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed">SQLiteDatasetPerturbed (class in graphnet.data.dataset.sqlite.sqlite_dataset_perturbed)</a> </li> <li><a href="api/graphnet.utilities.argparse.html#graphnet.utilities.argparse.ArgumentParser.standard_arguments">standard_arguments (graphnet.utilities.argparse.ArgumentParser attribute)</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel">StandardModel (class in graphnet.models.standard_model)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool">std_pool() (in module graphnet.models.components.pool)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.std_pool_x">std_pool_x() (in module graphnet.models.components.pool)</a> </li> <li><a href="api/graphnet.utilities.logging.html#graphnet.utilities.logging.Logger.stream_handlers">stream_handlers (graphnet.utilities.logging.Logger property)</a> +</li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.string_selection">string_selection (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> </li> <li><a href="api/graphnet.data.utilities.string_selection_resolver.html#graphnet.data.utilities.string_selection_resolver.StringSelectionResolver">StringSelectionResolver (class in graphnet.data.utilities.string_selection_resolver)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool">sum_pool() (in module graphnet.models.components.pool)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_and_distribute">sum_pool_and_distribute() (in module graphnet.models.components.pool)</a> +</li> + <li><a href="api/graphnet.models.components.pool.html#graphnet.models.components.pool.sum_pool_x">sum_pool_x() (in module graphnet.models.components.pool)</a> </li> </ul></td> </tr></table> @@ -1111,18 +2125,46 @@ <h2 id="S">S</h2> <h2 id="T">T</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> - <li><a href="api/graphnet.data.extractors.utilities.collections.html#graphnet.data.extractors.utilities.collections.transpose_list_of_dicts">transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)</a> + <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig.target">target (graphnet.utilities.config.training_config.TrainingConfig attribute)</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.target_labels">target_labels (graphnet.models.standard_model.StandardModel property)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task">Task (class in graphnet.models.task.task)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.TimeReconstruction">TimeReconstruction (class in graphnet.models.task.reconstruction)</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.train">train() (graphnet.models.standard_model.StandardModel method)</a> +</li> + <li><a href="api/graphnet.models.task.task.html#graphnet.models.task.task.Task.train_eval">train_eval() (graphnet.models.task.task.Task method)</a> +</li> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.training_step">training_step() (graphnet.models.standard_model.StandardModel method)</a> </li> </ul></td> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.utilities.config.training_config.html#graphnet.utilities.config.training_config.TrainingConfig">TrainingConfig (class in graphnet.utilities.config.training_config)</a> +</li> + <li><a href="api/graphnet.data.extractors.utilities.collections.html#graphnet.data.extractors.utilities.collections.transpose_list_of_dicts">transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)</a> +</li> + <li><a href="api/graphnet.utilities.config.parsing.html#graphnet.utilities.config.parsing.traverse_and_apply">traverse_and_apply() (in module graphnet.utilities.config.parsing)</a> +</li> <li><a href="api/graphnet.data.constants.html#graphnet.data.constants.TRUTH">TRUTH (class in graphnet.data.constants)</a> </li> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.truth">truth (graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + <li><a href="api/graphnet.data.dataset.dataset.html#graphnet.data.dataset.dataset.Dataset.truth_table">truth_table (graphnet.data.dataset.dataset.Dataset property)</a> + + <ul> + <li><a href="api/graphnet.utilities.config.dataset_config.html#graphnet.utilities.config.dataset_config.DatasetConfig.truth_table">(graphnet.utilities.config.dataset_config.DatasetConfig attribute)</a> +</li> + </ul></li> </ul></td> </tr></table> <h2 id="U">U</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.coarsening.html#graphnet.models.coarsening.unbatch_edge_index">unbatch_edge_index() (in module graphnet.models.coarsening)</a> +</li> <li><a href="api/graphnet.training.weight_fitting.html#graphnet.training.weight_fitting.Uniform">Uniform (class in graphnet.training.weight_fitting)</a> </li> </ul></td> @@ -1136,6 +2178,24 @@ <h2 id="U">U</h2> </ul></td> </tr></table> +<h2 id="V">V</h2> +<table style="width: 100%" class="indextable genindextable"><tr> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.standard_model.html#graphnet.models.standard_model.StandardModel.validation_step">validation_step() (graphnet.models.standard_model.StandardModel method)</a> +</li> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.VertexReconstruction">VertexReconstruction (class in graphnet.models.task.reconstruction)</a> +</li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher2DLoss">VonMisesFisher2DLoss (class in graphnet.training.loss_functions)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisher3DLoss">VonMisesFisher3DLoss (class in graphnet.training.loss_functions)</a> +</li> + <li><a href="api/graphnet.training.loss_functions.html#graphnet.training.loss_functions.VonMisesFisherLoss">VonMisesFisherLoss (class in graphnet.training.loss_functions)</a> +</li> + </ul></td> +</tr></table> + <h2 id="W">W</h2> <table style="width: 100%" class="indextable genindextable"><tr> <td style="width: 33%; vertical-align: top;"><ul> @@ -1156,6 +2216,18 @@ <h2 id="W">W</h2> </ul></td> </tr></table> +<h2 id="Z">Z</h2> +<table style="width: 100%" class="indextable genindextable"><tr> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstruction">ZenithReconstruction (class in graphnet.models.task.reconstruction)</a> +</li> + </ul></td> + <td style="width: 33%; vertical-align: top;"><ul> + <li><a href="api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.ZenithReconstructionWithKappa">ZenithReconstructionWithKappa (class in graphnet.models.task.reconstruction)</a> +</li> + </ul></td> +</tr></table> + </article> @@ -1180,7 +2252,7 @@ <h2 id="W">W</h2> </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/index.html b/index.html index 7e48c3ed6..fe3017fed 100644 --- a/index.html +++ b/index.html @@ -413,7 +413,7 @@ <h2 id="acknowledgements">Acknowledgements<a class="headerlink" href="#acknowled </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/install.html b/install.html index fa6e7827c..a9b5b61a3 100644 --- a/install.html +++ b/install.html @@ -505,7 +505,7 @@ <h2 id="running-in-docker">Running in Docker<a class="headerlink" href="#running </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/objects.inv b/objects.inv index 3ef123ab7aaf3155a34935402f71d0a18a79a679..e2a8e9d3320ab422b2aefc6b85dc595c1a5818d6 100644 GIT binary patch delta 6145 zcmV+c82;zL8^SP<cYj^oa^p6ZzSmQ*YUVPn-NaQ}TbZkl?e2JEJ09CTsoEPBL_!i` z{xtyFR-R^_WuI*EM~ak4;s7EkUvygl&-uOs;NX`4xS^=nW|#-X?&ma*%LE7i8K8pv z$HqKZZh4ygO+Whd_u$i_eh}pu<0#{-U$++NM~M+f{Q&amuzxPZjJ6yUZ~5i=^YgdI z>n8~uC5W-$M>{SAaCvoo{c!pF<F!OaoKvzYIUWuD=JNXT+b`Gu{a@3-|MvOo*V~Cq ze|Y@;?)v%H>u*of;D3Ag`uO>3T;rb~zdiptnvW>7X&gjDPvadLdji+cz!Uj}1|QEU z<bOiw6;qC>rGGtY$y(Z^w(OM%T{c<z9j}AL34Mv*aU`=trUmnD9#2EFCJ7E%xn7fZ zNmbCi5D#|KkZmvzQ&fcOGK)CLv#Ar8V!FYhOd^|_ilJ8=)?MnTL|=scPkNv*s=pcB z<tB<%9@(qm2_)oVI*tfMTO8&kFG{{I?b^vkVj&sf;eRH{S11Wd7UTEvNRk}Ipk_-8 z0@^GiD1G#{1|RmGCN>asxu$ulyq<}gEt0w>8-V|VnW8*N5EnH#1LZXsH_k)06>XmO z#}RPZ&^W#+1t>aZ0&q+(?B!*ilxcQH7$e!{`aQx$)g<FO@FSGfPIt9W(y$-;V=qx0 zhY4EY1b^U!C(sB-kAz!h2-H2FK%}Y{xLTb+R3N?ub^j-Dd@nKGg&eJtiF1w9l5h7! zQ>5uSV>lJ~Q#!S5bwX47A+3}LMS&>8p^^z>yp}zjX#&P%6D*?ap^cf51~z7uwD}vl zHJm-#r*SVIblAx^1}!<&1dpocX%G9J5Ys5ybAMk812fwQ>pzag+^KsC@AB!Nw<73N z!UIN}#O_hJ<Czl5mbIiS9G_AYdz4cOrGe7QpogZ3p5=6C;jwA>ol_L?oZ41H_kxzh zHU!B<9e|@Zl)RXhHy77U*8Z#;BK2WgW+=6H-J`S0a)zuGj7=;C;t8(TL*~ooJ?|Vz zUw>mHY7q8v<U?H-e?HQ^x%i@^e7w8clG9DNiST8)!g_7^W7u+qadDYb{BbP$RtT8* z90om<3B&g}Vfa~2F1{paGp|h>oMB2LuiA5{>r#7;bZ;)c>L{Pu(`|{xgzaucNlbB+ zd!@cZ-6aFM)I}bkMEoStLxM8=s}gw>B7Z)4(rcf^u#Bf=o~hLi=9Mv>04`b64RCXD z|9mInZ)Lnsh78zcF-S~SM42Rheaw-vOI92!i#Ef=`Q$;5F*Yj*Lw;#*DmtaRr6|&T z^Yqy(y~hSvn8mDz`85b_O_$UjXx?1>^V{{)?>_16vcW=5F{geBZfm+E_(1dK;(t*_ z_$0W?1`9bW5|ZJx@T<<p+Aaz2(7w5NQjvEL*Eg3wS<q*Z#T2!g>6Ha-O_$^!Xx?0? zQIAieyKJzLm#JTJ8+uNOZR&}ne!KHZ=n?xMZOKVOIKd991+@x1aeI}hp6hurHq=jK zU#=6xInKfiGmhghCMq7$opsq6W`CK7NlXk%Q6dKS=QBY(&RCH%911&ln6FjFuz3wy zQ<VC7q}KFt4Wp_4goK4@o^c!@n&f_IZ(8DMLY8~Gf<L?58k48t3*NC1WpP-@8K-ni z%w{u%?+U--SBgHRuN7Amqn9w+MG1*QF%v2&gS>{anc`$agk-f_{VD2A$bWN@{S=4A zj&ILmhbsI~m;PC;Q5&{!y23PkMWT^hCBE+HYk{bbt2`HZqWj_VaO12@XQw8Ss0k~4 zri7bC;JIn#j|oU&ub^SCp1WR+Lbsj{>Yl>eeB%q;p$f2N3?mMAm=5p$o69Kp&d|9m zf^H!7vxgK|&OUAKX~b-vSbvG#Wcp3(Lj!)m;acd&u%OKYpiKkE?{J|NbU@Rev++I} zSU0U^oXo#O*-no$<c4fD>diu&7JO&%zZpz(YB6T82i{oBU<|aK7}*qc0)v*3ZgAA( z!rD)dY=Rt>L?Ng69L#!1F6zqk+|ae59xbPKYOf_~8H%>Ro$Z6vaDUE7O#DPdy?1GD zPBqh36-rF9gydW{$cHN={0YAk23X}8n^q@&UUFJyGK<3~k7q0E35yn5NukwpFINrh zwB7a*(SW8uXEjEB&UGII9Jg3)8OI3eOqTBNBul7h7xEOG?^+z3*+#rTguW5E&I3G( z(oHzTBzkO;Z9<04a({40G0T%z>nIq208VW`!jpRPR6TfX*aq=1i)o|Uv@`->r*kq6 z7fHTLaW=F*QOgWu{9cA!XaTmYo;FSrSQ&b?Pvc&GmEn8*d~<PE8`E9VU|ZacxVgB= z)|lcf!e_NYZ7I7<Fv>8U#t`uUCj<SHe%@cE=6RIMQzDsgXMZcyUep0~`MH|kUACAy zv!aK{!Y~8ya-GvRMC1Mz;$#Mj<1)fwUevE~duPU;TqWaqdiDF=WcekqKWFhXN%0?I zqLIH%Ni-RDnXM2PQzaa4uzkMd#E}+xh6is{IT8ibv|2PlJ4~IxwT3&#@@kh|i~Ny8 zHhcDW(mC@uRDZ$K2%qK4`l(Z;NajgWm4VNM=aVL@Ni=`bmMLZoZSahYV$>%O4<Wf+ z#;d9~K6R1jiA$}P@-B5+$!}<07U60KYV10w;2e7IgPgFLS=d~3Mark7OGy@1MKv`? z_&y~I<MHGvsW>K1NtsTc62qaE8)L$%LE9XZqh?+#8h^3jZJ7sx)<&ilbcgomSY9>A zhB~_*E;o{;Iq*}xJQ!^MM`8qatv+mKTME3Ki~_uvDh2zQxIrp64!=_JCbS-|8=Yz9 z1pbCLn`DOMX`Y!AXFK#aHqq3{D`+kuamnD-danbo*$thC79}Vb4;kI5{PGdPyU8+Z z)(IJFR)0wu`MP-UV%TF@7u=<so!k#xThRv(#0B?0V1js;$+Z?xh8S5;hE%d<4FQrN zEqEZVQ--(@gnyEZahGh-8ZMb*G)z+d=jQ(D*&&~reC|X(8Rx?IqWX@In`IAhCN}h( zva+dX()oPz{JTqj9#MO8dSxw`9NHQ}TSWQq>wo<{B%^9b3zky`fI+iWNV1~T3%Gr^ zmv{o}`z^;omSKBeuf^;oq{XC*fH9S@6S5D!FN<2lyJ}sCHt2UfoL*#b$2^g>j1m%K zq?!aaaT0&Kzn_EV5yb>0Q!#bjlLzs|Lp`WN9_gv~=gw}~O@9KftB3j-z?`y|Ic2JH zNPm?omprt4NG3vHyOl81Y%T18@5{2h7ymo^+7YK0nX7}1#Ol_^j!G@x5@n@2XP2)V zLx`zU&ptD?4~30czUq(FOVuB%V`{0p3USE92A~CLxhCA2yDo4*cXfe;-2CefEANyY z8!u#BZm5>$)yIS`D4qy1M67*?VJv*P-G6v&yb*I(enBG^ej36mXj|M`QxC+SOb}Fc zyQo|_)LX1lWL1(gd)qk@qpY8DJyGt-O37c4H~gb8LtC6t$VjzcavnNp#UGQ*tw|V@ zS?&$icbx$S9itGlEfUnCUBSgXDdmXDdj#f9C8vmFGo?>bsombQW$7ytb9bbiXMa>u zj@XN}V~|4D?m!8tHlM6Xgh2D<4jmt!<u63uCv@40d^RZ?>V|sxPbFT8$uBAs9(GoG z+b-Ve%{r=~=QF9bf_gSHZu1S{>}>v`=|8B~92oa;e8(6zuLs?UJRmA=j`l`g6Kl?4 zgB#Sg4y-{v(E0oqNlU(cJYu(|u78XCB-Rls#dJ3L)rIlxajycS+wceB+uu-8p!2v% zP5%-7<=D6n;y8Q5=JlW}LaXzZ=kXz%0cZ9iuzB<Qk=3+$J?T2bbhDeeqv|3(v&RgY zG{2X0nsa{7kY4(bzhQb77pd7lpsyU;_A#7hZ`a50o^M%%DMuvJ_hy~NhkqJ2d{9s7 zvhd^hQ`5S6z3L{z!ePYxk?^_utbXb<`w+me`CVtgsChl+A!p=xv^BZ6bkm;MQ!GZ! z?<GA3&FdlZa)p__<CyZyo-tt1{9e&v%)B1)FL`Iy`^-MkMS5nh7&K{qzvwh)UeEYn zoRQg=b-PH<>=}b5&F>kV=6{^uGv-^48)W~0zVZ<k;)d;-*n8mn2H=|q)PFZI&bSzK zgI}lZ@!`P39%b3BayqN){uqBom*YuJETb2XGS_KGpF8|)r9NM2D<<gXKzC&LPC$3; zNqtlA(F*Y^O2#l<-t%cchoPjbZJm~iP!g4i-svkhYen*mx@e^#*MGM(cpwDtyT5q2 z&C|Ttl4Lg){!5mH`k;@x9Pb2J`7O}?yp>P;^He=u6V3;3t87Zau;kfxo|p8C2oAh0 zI0gnH%GmJ7Snl|jk^Me;8pBT`GyH+cX3Iy+lqR6E_)SHyam>~GwZpQAkukW>MIop~ zl0*Eg60<m3Ehc}2KYv~|3u@5HHogt`rwM>5MsX-c8N;7&v|(1eX#2CXlIYLLT0*pT z$qGd;41iY4y5KG!2#&D4zWbJB7}49uUjOeSa%NCpV@n3At~V3SHxx7W3!>wvVvlXU zCMSBeh)tS^r94(w4eLVwJx*t!<Cx9IrA9L|ku~<3PtPbx^nXyT$mJEgSv)JB7<y4Y z7UO=1Z)t(FAfWj%g7uyodu?q(4`SZRKV#!9U+`{dIPU^0=73?<w}3LWr2xs?wu0GZ zOhxRq=-cvQcWb7hmK^Hf1rg3jRp0$^WhNUF^WEl~%e&tuCs@iVIXv|pjmkR79)~Gs z+sVs@u`kt_ZGRq`IWq6=>3ZUfdz_x)UwL*X27v4fVOvapx|)^lVqUu0$%MovjMA4W zy|AA(M1|-T-_Jyj-%-Rz<+dz|wdDZ4v01gs6SZcw9{lcg7x)mWoQLceec71MDSFZ0 z{F65$46iYGLF>V4V@N&?m%4dCjFHQPkCZ)nAt$u1&42YM06gYE;g>&iOtahk&EBYV zQGqzVWkhsk>4XfR6}E9{hHpoKa{<1P2lzN%;{%m+Q^;?qK;SA_7u@A%x+&#HmL6oN zKezi>+~D^|vXO}&el<H1>l@)}mfJzyUEaRDzwf_NG-@syT)CDjx9m=5`Lf81(7d!8 zjihm-kAKC3wIsxUjimapr!#&-%a3dMi}-z6(^|^9kJX26DevqUyeMQM-ndMQ*-3|Q zZkmWxy{T(B6`*yRfPYPJ60_O(<>(M4N!>Wg(N>8Ehe4;(R~HFY-$4de-lm#E?QN-D z9#F$7@G-L|WuryxTO^#29CXXtg#2olJvqOeB7evW2Aw;HwxWZF9VlMupldJDMeon+ z(Oryefcf}<V@o!&5nY$qp3NBb@kH?`Dbw>BqSiUgjOK%xQ|%j5z0ilTidX%`*{ty& zn@nPe&md)FL*%urA*UtZeh@Q0fR_~L%(<aHE|{X?ysp78j?ZjBvVUcLJ}VBs%6?X3 z&VOH1&sRy+!1s7u3Z7rN8*tcN0EfU$HHQIkOHEzb=sNg4CiY`ox-Kt{&hI?z9*tV% zbeNedUC)qq|1rwQCisj_PBKAN@WbSTtG6q)TzDOPtJJ6AP^Bx4YI#UXX0i#APUg5g z)MAIzR#J0u5_Uj2C!s0Gi5U~ZbaAqBrGFVGPb8<JZcv8R*MLz`44=U`g|WVv*ty#~ z&~&IKhnmkO?ArfzT5t;Wvi+tU-VTm$2}U&8g^cG#A%(;og;a;C&S{o;?|V)Y{5#C4 zZ1L$i$G1fv8l29G7)LopJg1<S^?PYZ!>}fzQQd)nji>{v_$yjbauKR9tRZc&SbyNR zV)%;RInJ1wE{+Q?JG>}X#!munGG;t%k`_7Tpu#kwkix7*MQ)p~7K>IfdO;hY=`Ah* zi9C!nErX~X(RMltiochH;ySy=V@UGN1~`^BqL31<L~q4k(B7IBHVAYXQ4t*to$x&g zKE|tZGZ8kXG^Z0`%W55N#b7frpMR&Zzi9PfnVf)R>wKadVwm6JS3H&DFitT_r!Rsx zMC${T<EO%)hxy>~e?N*xTwuh%kc5xkHZ-ASRWJ-wSud|KpeoGdv@(AKue(DSQ^dER zzE=rC+GQ<>3l=8i1=x!;z$GeL<s}aokAIHyDEK7QWQ2Mie3m~KO`Mis6MtXJKissW zm+B50Bv&u{^lAx>PzT#Ma`M*;As=dO`RiB_U`1QZYKr<#Jb#sUfYD*pmgwHhI;scV zOm(f$T5gZ4!#$vEpa7dfdpHs3$8y7T#4Zs1UP7^ru*Nn4J9hOqbV^5$a*I2(r;mLX z^%27;wX9~)zpcRLk6Y3@Rezhb;B*~egtEIV=}I)G{)X&9x=h+X?J!Qv0V#^q1bL8u zFQ+N?;X%+Pxu^qRf`*PJYed-C%w8lP(nHmX`!L>S!UkuUis_4o35Tkc35W4E6Slh* zC9%Be%*Toj?aGjjM>bnBo|bv04?_D`)2UzC)A`V5k!XtKr;WUv<bO!FvgGK2)e13H z_p+d=SF<-GZOoz_W?qibR;;+&u{LswIrSoMD^}#~SQ|Mj5|ZJx@GAOa<;sZTCpJUW z4h0`0+KQC{?N}>$nR-z-)N0~J+%WC_)D)DleTXWkgFqg2#uA&8knWrf?P}CpM?fu5 zUtRSf*NU1~|AtcXU4P-lsPr_U`q{w;8g~N6n_WXi;Iv^-^?0l~G@U3&fo09L5E?0i zZ9*Mzkbx^`z<NB$1Ln<u;?c-(C}Y|<z^dD$)FK|L#6TK{W??T{c?8cxNGS~{wmqzl z9%QE|i}MspsiifK{yhkOxc5&UZ_blfo2vi-Ov&{<KClWGNq@eR8?yR&YY{Wmt;2m# zqSvW5F91Mm;i`8sK|8vM9LlpHt!8ri^QED9xEaKW^*~vZA4D!XE~e~dG51y~9R0gX zUG#|Z7^qr_wIJ;!tFE&O_W-9wp5dWos3XY=xC&c<mkHWo>VmHoG>!gkqbTw`aUhW5 zrZEw}I>BSi@qgeT?Le;bLme7~bjS7~sb1niIgGrizdA6?P|#d1dE=6SN%krx`|EHg z)h5eq&_cx2!Yu2<kc+6VL+ixHqV;O44D{m0y9~!048b*`rUZ4sh8zA#GR7UywL*(S zM_HL(&}Nh1kQ1J}@HN!htc?e~+`!Vcm#_?5eNam^T7PUd0JfBgfI=!+vot4vT5&zl zT{w=8<H3jy)f!pVh_w<k7dmix$OAx7WF@PC)lVy~2RevD{jtHSKaBxECsKX81_-h( zT@59haf`l7%5Bi~9p-`jwE%vB9E95O(ej`jw*Y=v0z}vGEsCHWw-~(19Yg}iE`7<B z1#Ev6l7E_X9k`pi-MkN>zT7;JzZSh$QV#<(ujC0#V9|dU^DtQJ#^#_3EjsT6>P2eZ z+#FP)h2?HlJ6!u3)}Rh9hHe?P!!~Xs4fO8?<a#s!@bPtOK|5|SaJ`ltf7tc<?&7k- zc3kxowBr`tTLpTdVT%(TcR$$yTleq;RcJA|SAP}&+t@@K=-&(Q6(Df<!+S%5M#KVs zD+HYI=q`z%8MAOXg$F0?IK3CN;}-DqL2%lmQ$;~DW&t{O0cUHT(+Esp;k08L&egTm zJ2;43ST}P6_zre{2mA11cvCR|ySe>1FoDJJMpgi7bIWUB0*m1-SOC=KX05;k7DL;m z0DqYKZB>E#sqlK1^;!#fj7AH29h%u;))3vtYRStwyujrroQf*0ms?dY*WSDJx4#Gf z_wU6z$={-_ycE4(>{rnfd`i>?%KR16SAySyd>tV9hD7bheaq>K1PV%3g}%zP-{a?> zgZB0;oXIwa1!l!vipF^=ocI$fa3n{Ts(+OT#I%=WaVf@ryHC|n`3DijF->tykjMkI z>0MsUh=h{}i9mj%LIUzFVG89R!Qrx>+OEf>R^0|l4XM2<Nut!{x>SPSgRl4ZW(m-w zi4pnWhoIQ*7>U@YoDwVyiW^IB4aR!Fl%fpD>7|}3)X|a?zZaMiwNPvv>>ny#OnUBw zrMn-3t|jG08nj&sQkY?kl@InGaU$7K1|JO0MbQ&6P+4r`s$?t@-tvMfRy)^?{JGeR z86Rv&=-q&FjYHVkmQhg^jtVS<$sL@<?o;(Zi{I&`Va?unYEgUG@vTT;OPgUtqitJJ T$7!7>{ro``+2a2J@}KaTsY(JS delta 3434 zcmV-w4VCi3Fu)s-cYj>la_cA-eb-k|HT~G}On0iLrn;ZfIXUOJ=3LUb)I6%Ngl#&q zPy<ThznO2DFB@)>kt_r#vePe#fW6n+zy=}Mctc^eEfELR{+Aue>jHs)09fI_Z%n~r z%Xh`!bkW5R;38=YGEy=QOU~T7t%w^Xha9>AWa+qWh#6hEP=CC|$D6B%$Ge+*37i#> zG4PWgHw19;;pXP{`up9DM24JGys9~x4SjiibN%?+&A<M68u))-efsp}M5f=~egAs% z@cHKP{xtZHx1a8=KFn+U!`<V<=h=Kjp`FG-H1uh_Lt~%7H8k*v{6d4z=M?gf5PC+G zBkGxtI<lTQsedc`L8Pv27PsRKkVK&`&@0MhbI7_Nz9so-Xx6wuDXZ6O{3@v`N-9yX zKMmOi@pK2PbX}Jj$D}-U<#vc}P%5j)PEEz&GfLZ6>a0W`h5h$>p)jkz8QjGt%T*is zhv6GY@a5?^GMH^qN@`x!{LtI2lg-3}OU%<vK~}IxaetYk*Lg_LE2l8yg!-0I2SLxA zWeMWTHW7|)DNSNpEH7_tR=3lRkxpS-lyHZ<)7b2+4Dp~zW^UqZQWU72NYB^I2i|j3 zeT^WmDVpTSecc#;KGI!Ye$-i0ynEP^x0|rc&~?2+dQth;u;l|n)it5$uVc|;C1CRR zFzB`}7=QXkFhdWt=G%`2+?>~@4Jr}E*`(3a*NquHN4m?)PdaPL=oz-eW5Ra7qBy50 zBa>=BpdQlzL+YXpU?CoHc3Z#_eO4lOLS#ytIPLQoR`Oj<$^x#CaLp+VGYAmVG{XSP z%Wn@~MgA9=KczuN?D7~?K`Q#BnD&l2QjTegV}E5aXPCSxZQwA*X9Z?(THS5MsCsu4 zMV*)TSCi^{Y=DRP&l<656GB%rrnWuJ<>l{>H}~JC)OW}R4>?7gPOETNGp52l&E@5t z%$QQ)Asam8tSWGccGa}ue5@T)@d54S<-N-MdV90Hp3(#^i#(=qx!{=81YOOT>i0C4 zmw#&2Go{i)Hu%Wv-L&djdQp|N^+Z)4zfP*?5&I}@&2fP_MiILOtx9<14k}sutv89$ zQhy`+VqHMaQJI#Aag?VyR{4<bz56g>mQ65;iA5<a#Nz(_Oc2E-s|Z7>u!E;$tr~`% z*Pu0pyJ;S&HK(|S(NzCZ%+ehxIm#d{$bYoD+m=i;p(K-CVLH3qQzyIh3GG=5%RH^* z7pHVg#Li|4-&fPRUnx%MeXY2n2tK9RJ}YpRiZ7v>GP4;LUv0~jC>t^;hu!KAF>X>$ zMDtUWR(rmE7duqvr@Hsw)f#ougLW%K(`P6q$yG7c9j96##^Z_*(I)ym{65^MtbccB zr^cBW2`ltY6*rr}=cbhpCm<Ey)8dnhd*(JH<(bGDDf2&#IkO(seGBj6R_8_>Z3oyg zhY|gw>G5O3a~TzWq<k)mLQkG2*~9)6qXc=JQww>#&<N>&QNBN-Ih}KH|N9q_8R7O` zGv3U<faPAl!ex+zT2Pj$`0(d@Pk+~*!L+Escm{jI$Ke@_2_1qno1&tkXr1Y?tk&(` zSK!Ph$QOQ=a*E!AS+91*SUEj+=(dYPD{7uPXo*n<vu$9s$tZO=XC}U>L<W~j_V?ga zAF*mciSHL7c`h4d;R*_WLa&$wY@p0dt1EwAa$1+NiPMbaXY1<|c5Ub&4SyY$2f1d= zr0w>vh(<K0b2iJr2cq9Wf$J8pE%O*5oyqh1dy^$p%nM`}-0oUDxUr3R1u=cWatmcE zmO23w;4tABlXzp}a+5M-c3M-47%85;tKa~H;M%$o?p4WsQ*dY57UE$Z(`L14VFZGm zF3LDm6=c6d<=CfxE8}i>d4E3bMv$fBx?<hUxv<)UVWh&a7GlhDQqzwjInuSqz=4P| zH~lFWZxxm4?w_|Y{5&(me-Pem`HYQX11gJOR0f;JTv7Km-lTOUyj&aXny5lhg(rvj zStI6gv{p<$q(6PIJ8IO*F1`===Mw;nm1v%dAw7Pq=)>d`;&QY(8-EFROSTf)ZBp+{ z;82KWSp4<}LA3lrUYx^Vlq!hhvBYaacPC^B8qf36$Wi7V-QipT3)z4*to0p9H)6%b zPzYQj8-u&JqMKU22ybP|-P$geIQWIU3YJ5fr_kiH(JiT&SYI&Tro+_3{XLcUd)0<p z%-V_uSH3y8@Wj(uzJI7lmF8#->w?cllB0}JIDSuNNIJxi)i&Zj2-t{5po;%xo|^kl zqo4K;77zRuvJ|~?RQk@or@>>E+84xT2e&SU&4Z116`?$+GmU6SXSSm7aKKu@><MlH zP4938kjN#*X&J=m$hN1Wpy<yUi%+bkx#nXi$YwLT5;_tM>3`uy^o#f(^bYG_2Z1g! zDzY<Jgv^Q);B&O9Hz&eIloEO(Y<0M_6>ru8>FM2dfA38!Cnvz=nw+SG7~)^hGdh*y zFfP7(cBk)x7f9RBqWNoK&{I-oXl{x-R3Rvq`30YSZm^-{P-srXPh|a~T%}Z{Sqh7S zyr4XA`Fz9(SAV4Dfbsm7oMhlaD9Q{~4zA>5Vk30~`#*F8u5Opf@K4w{^}4?b0or78 zv<v8jXn*Dg$fsoFmnX&{!cFP@RK+Oj@j)f%WI50a0`m}PEqB1xX$L49=|EfP04EY| ztT@a_>?Ud}2-^<c0w&?d?u5~>SgxLOhdcGx&zwO8QGZ(JD?mScf$ekh;HnL(+O)y- z>S3hvHB{4;n5L#1asU}J=}?-C1r3h#_kN!N9MVbXefVI4{vCQ?Wh9Z;h$Bb`Zy<x? z{3P-^Z-fbhH-W)%eiDCsHo^+m-CHnNluWs}8ext5;wx|zCb8E@Bdi&|f(ix-l9<b# zF%}$M3V#K?%p~fXWd!xW{Y=o$O`@(mMvxEOl?46VB>J*pgaMA5ieRWDiMjw7VZ^}A zLD0`lA}-~{Q1|chfo@(B7Lp#!*gXUQZ8C|!?=^w5e_IS#8A;S-qY32JEhV6zlbCq_ z35<t(0zf?`&JgQy@WA!v_-%n826v=47%UpH0)JOXj$99+kuumOmAC@B#=(LXtYgvS zA?VIXP%u=I1c#=@koQJ!KQ|ePWf`Z5m=?xD+yXg<W4XXlcH)no3Z=ASt$=<`CI$Y4 z@{LCh0cUaIh!=A6c0-E5A?<<^Zymj@%|bw1;pXjRl74iPIo4)lT6T5Z?P4i5rK#EB zkAGtADQogxX4eLhC)USK7G11y^z)Fqcq7VVOtr)5K{`k_ug)gz0Nz!kL}S}fN0JS2 zllA~F3b;o!245>^8vX2}s0b+{2&A}~OvKY5c#b$4U8MEo8b8#jQOJ1gDJ0F|SSW{) zH{(|)#u+L~<W@4OSun{##Wa7L9;DiorGF1vh?r4WmM#priTXUYPkb!e?zYN=6F1+N zDBqwETqA0FP$zu2=^wacJOW)SbU1XDm+1*^HU&yK<|KyCQtPr>1y1h3-u3seOnZG$ zOIb5E8v$F$N`yixS+g}KAHBE}&|^5%KN%(?e^^$YiH8yO)vTtna=pYXhAy~2N`H|G zMOLyJSUq}iC!hy$sDD0K_0L)W3?jABVL^~x>1HWu=M9o@$ew)AQpS9$iYF(`LY-pb zHO^p&2M(y13AmHzrjxs(=E`CyU01~{`8#=Ty5N=yXfF9N;PpKvbol7%xtqER;6iAx z0s#4xI0LxH0ql@cU=oQta(NuA6Mx?fnp83v)9OU(#5sc|H4L)IU{S1K_Qd>ft*G(9 zCh0)d``><8E!V7RS4Ru|gCJq7{T=%d;Euc@=;S$w#vQLK=5#$*EV&nUAXFK2^OE6U z0Vj0W=EOkkvLCP$M+%yh3s_&E2!Pds?NU=bk3%lhea={TOG#5%86H0UuYacM%5E^B ztz}nE=XfZagVWercJ(^v`C)ZcGi@}jt!2H$!ziT}jSUu;Lc0&f+{2_%X$D+$&3cO6 zY1UIeE2P@b*NdnIhjcO?tQ&xBhxG!JNX7$115oWyQeYCvcxY1qsvQIcOu~h_gM1re zsgg5(DIXP{V_UDYfyZc;$bWEt6S-xH9&)whb(>C5eSvmji0kcE^_J_#ZvE{C@W1~i z>w>&wTS)nc*k|!3xF}R$B6&vi8KW0K)&R;A5*6qALg<qO0<{`KpJd&?p@&~UH#`fK z`WM2?j=L1i$xb-&1*=ddSC(p)0OGS3mw7GLe*24Nsr;kJ;)w20j(?$O0~JgzuNFb! zBtinn8(k8RZ!uFSzk>d@p9+IxQmdT<rH0gA)wodVa$l-Iui(?SZ)OY7)QMqP@FS?U zdxkT1K`2JTpxoMevl!b2(+-wUeqQQNg*IDqqSp#htTu{`gZ--V#mAknbpIn5+ETXa zpzTtS!VF`qEWqU&QDY>dBXN1uz`ez}=z1&`D#=FfO3L4#=y*YstDS2pACrTa@j**M zM+3?|4q<0kMrAcPDv=122RN(!MN^=~zv``F&2E)?)Nb|f;dQhbf;8K96?LB0x!2{l MqRW#116ZEXq$@eMJpcdz diff --git a/py-modindex.html b/py-modindex.html index 058464641..c3e8451b2 100644 --- a/py-modindex.html +++ b/py-modindex.html @@ -360,6 +360,46 @@ <h1>Python Module Index</h1> <td>    <a href="api/graphnet.data.dataconverter.html#module-graphnet.data.dataconverter"><code class="xref">graphnet.data.dataconverter</code></a></td><td> <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.dataloader.html#module-graphnet.data.dataloader"><code class="xref">graphnet.data.dataloader</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.dataset.html#module-graphnet.data.dataset"><code class="xref">graphnet.data.dataset</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.dataset.dataset.html#module-graphnet.data.dataset.dataset"><code class="xref">graphnet.data.dataset.dataset</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.dataset.parquet.html#module-graphnet.data.dataset.parquet"><code class="xref">graphnet.data.dataset.parquet</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.dataset.parquet.parquet_dataset.html#module-graphnet.data.dataset.parquet.parquet_dataset"><code class="xref">graphnet.data.dataset.parquet.parquet_dataset</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.dataset.sqlite.html#module-graphnet.data.dataset.sqlite"><code class="xref">graphnet.data.dataset.sqlite</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.dataset.sqlite.sqlite_dataset.html#module-graphnet.data.dataset.sqlite.sqlite_dataset"><code class="xref">graphnet.data.dataset.sqlite.sqlite_dataset</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html#module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"><code class="xref">graphnet.data.dataset.sqlite.sqlite_dataset_perturbed</code></a></td><td> + <em></em></td></tr> <tr class="cg-1"> <td></td> <td>    @@ -455,6 +495,11 @@ <h1>Python Module Index</h1> <td>    <a href="api/graphnet.data.parquet.parquet_dataconverter.html#module-graphnet.data.parquet.parquet_dataconverter"><code class="xref">graphnet.data.parquet.parquet_dataconverter</code></a></td><td> <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.data.pipeline.html#module-graphnet.data.pipeline"><code class="xref">graphnet.data.pipeline</code></a></td><td> + <em></em></td></tr> <tr class="cg-1"> <td></td> <td>    @@ -495,6 +540,156 @@ <h1>Python Module Index</h1> <td>    <a href="api/graphnet.deployment.html#module-graphnet.deployment"><code class="xref">graphnet.deployment</code></a></td><td> <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.deployment.i3modules.graphnet_module.html#module-graphnet.deployment.i3modules.graphnet_module"><code class="xref">graphnet.deployment.i3modules.graphnet_module</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.html#module-graphnet.models"><code class="xref">graphnet.models</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.coarsening.html#module-graphnet.models.coarsening"><code class="xref">graphnet.models.coarsening</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.components.html#module-graphnet.models.components"><code class="xref">graphnet.models.components</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.components.layers.html#module-graphnet.models.components.layers"><code class="xref">graphnet.models.components.layers</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.components.pool.html#module-graphnet.models.components.pool"><code class="xref">graphnet.models.components.pool</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.detector.html#module-graphnet.models.detector"><code class="xref">graphnet.models.detector</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.detector.detector.html#module-graphnet.models.detector.detector"><code class="xref">graphnet.models.detector.detector</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.detector.icecube.html#module-graphnet.models.detector.icecube"><code class="xref">graphnet.models.detector.icecube</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.detector.prometheus.html#module-graphnet.models.detector.prometheus"><code class="xref">graphnet.models.detector.prometheus</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.gnn.html#module-graphnet.models.gnn"><code class="xref">graphnet.models.gnn</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.gnn.convnet.html#module-graphnet.models.gnn.convnet"><code class="xref">graphnet.models.gnn.convnet</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.gnn.dynedge.html#module-graphnet.models.gnn.dynedge"><code class="xref">graphnet.models.gnn.dynedge</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.gnn.dynedge_jinst.html#module-graphnet.models.gnn.dynedge_jinst"><code class="xref">graphnet.models.gnn.dynedge_jinst</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.gnn.dynedge_kaggle_tito.html#module-graphnet.models.gnn.dynedge_kaggle_tito"><code class="xref">graphnet.models.gnn.dynedge_kaggle_tito</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.gnn.gnn.html#module-graphnet.models.gnn.gnn"><code class="xref">graphnet.models.gnn.gnn</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.graphs.html#module-graphnet.models.graphs"><code class="xref">graphnet.models.graphs</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.graphs.edges.html#module-graphnet.models.graphs.edges"><code class="xref">graphnet.models.graphs.edges</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.graphs.edges.edges.html#module-graphnet.models.graphs.edges.edges"><code class="xref">graphnet.models.graphs.edges.edges</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.graphs.graph_definition.html#module-graphnet.models.graphs.graph_definition"><code class="xref">graphnet.models.graphs.graph_definition</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.graphs.graphs.html#module-graphnet.models.graphs.graphs"><code class="xref">graphnet.models.graphs.graphs</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.graphs.nodes.html#module-graphnet.models.graphs.nodes"><code class="xref">graphnet.models.graphs.nodes</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.graphs.nodes.nodes.html#module-graphnet.models.graphs.nodes.nodes"><code class="xref">graphnet.models.graphs.nodes.nodes</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.model.html#module-graphnet.models.model"><code class="xref">graphnet.models.model</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.standard_model.html#module-graphnet.models.standard_model"><code class="xref">graphnet.models.standard_model</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.task.html#module-graphnet.models.task"><code class="xref">graphnet.models.task</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.task.classification.html#module-graphnet.models.task.classification"><code class="xref">graphnet.models.task.classification</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.task.reconstruction.html#module-graphnet.models.task.reconstruction"><code class="xref">graphnet.models.task.reconstruction</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.task.task.html#module-graphnet.models.task.task"><code class="xref">graphnet.models.task.task</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.models.utils.html#module-graphnet.models.utils"><code class="xref">graphnet.models.utils</code></a></td><td> + <em></em></td></tr> <tr class="cg-1"> <td></td> <td>    @@ -515,6 +710,26 @@ <h1>Python Module Index</h1> <td>    <a href="api/graphnet.training.html#module-graphnet.training"><code class="xref">graphnet.training</code></a></td><td> <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.training.callbacks.html#module-graphnet.training.callbacks"><code class="xref">graphnet.training.callbacks</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.training.labels.html#module-graphnet.training.labels"><code class="xref">graphnet.training.labels</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.training.loss_functions.html#module-graphnet.training.loss_functions"><code class="xref">graphnet.training.loss_functions</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.training.utils.html#module-graphnet.training.utils"><code class="xref">graphnet.training.utils</code></a></td><td> + <em></em></td></tr> <tr class="cg-1"> <td></td> <td>    @@ -530,6 +745,41 @@ <h1>Python Module Index</h1> <td>    <a href="api/graphnet.utilities.argparse.html#module-graphnet.utilities.argparse"><code class="xref">graphnet.utilities.argparse</code></a></td><td> <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.utilities.config.html#module-graphnet.utilities.config"><code class="xref">graphnet.utilities.config</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.utilities.config.base_config.html#module-graphnet.utilities.config.base_config"><code class="xref">graphnet.utilities.config.base_config</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.utilities.config.configurable.html#module-graphnet.utilities.config.configurable"><code class="xref">graphnet.utilities.config.configurable</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.utilities.config.dataset_config.html#module-graphnet.utilities.config.dataset_config"><code class="xref">graphnet.utilities.config.dataset_config</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.utilities.config.model_config.html#module-graphnet.utilities.config.model_config"><code class="xref">graphnet.utilities.config.model_config</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.utilities.config.parsing.html#module-graphnet.utilities.config.parsing"><code class="xref">graphnet.utilities.config.parsing</code></a></td><td> + <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.utilities.config.training_config.html#module-graphnet.utilities.config.training_config"><code class="xref">graphnet.utilities.config.training_config</code></a></td><td> + <em></em></td></tr> <tr class="cg-1"> <td></td> <td>    @@ -550,6 +800,11 @@ <h1>Python Module Index</h1> <td>    <a href="api/graphnet.utilities.logging.html#module-graphnet.utilities.logging"><code class="xref">graphnet.utilities.logging</code></a></td><td> <em></em></td></tr> + <tr class="cg-1"> + <td></td> + <td>    + <a href="api/graphnet.utilities.maths.html#module-graphnet.utilities.maths"><code class="xref">graphnet.utilities.maths</code></a></td><td> + <em></em></td></tr> </table> @@ -575,7 +830,7 @@ <h1>Python Module Index</h1> </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/search.html b/search.html index 32901805c..4199f1026 100644 --- a/search.html +++ b/search.html @@ -361,7 +361,7 @@ <h1 id="search-documentation">Search</h1> </div> Created using - <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.5. + <a href="http://www.sphinx-doc.org/">Sphinx</a> 7.2.6. and <a href="https://github.com/bashtage/sphinx-material/">Material for Sphinx</a> diff --git a/searchindex.js b/searchindex.js index 4df38d653..200c93687 100644 --- a/searchindex.js +++ b/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["about", "api/graphnet", "api/graphnet.constants", "api/graphnet.data", "api/graphnet.data.constants", "api/graphnet.data.dataconverter", "api/graphnet.data.dataloader", "api/graphnet.data.dataset", "api/graphnet.data.dataset.dataset", "api/graphnet.data.dataset.parquet", "api/graphnet.data.dataset.parquet.parquet_dataset", "api/graphnet.data.dataset.sqlite", "api/graphnet.data.dataset.sqlite.sqlite_dataset", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed", "api/graphnet.data.extractors", "api/graphnet.data.extractors.i3extractor", "api/graphnet.data.extractors.i3featureextractor", "api/graphnet.data.extractors.i3genericextractor", "api/graphnet.data.extractors.i3hybridrecoextractor", "api/graphnet.data.extractors.i3ntmuonlabelsextractor", "api/graphnet.data.extractors.i3particleextractor", "api/graphnet.data.extractors.i3pisaextractor", "api/graphnet.data.extractors.i3quesoextractor", "api/graphnet.data.extractors.i3retroextractor", "api/graphnet.data.extractors.i3splinempeextractor", "api/graphnet.data.extractors.i3truthextractor", "api/graphnet.data.extractors.i3tumextractor", "api/graphnet.data.extractors.utilities", "api/graphnet.data.extractors.utilities.collections", "api/graphnet.data.extractors.utilities.frames", "api/graphnet.data.extractors.utilities.types", "api/graphnet.data.parquet", "api/graphnet.data.parquet.parquet_dataconverter", "api/graphnet.data.pipeline", "api/graphnet.data.sqlite", "api/graphnet.data.sqlite.sqlite_dataconverter", "api/graphnet.data.sqlite.sqlite_utilities", "api/graphnet.data.utilities", "api/graphnet.data.utilities.parquet_to_sqlite", "api/graphnet.data.utilities.random", "api/graphnet.data.utilities.string_selection_resolver", "api/graphnet.deployment", "api/graphnet.deployment.i3modules", "api/graphnet.deployment.i3modules.deployer", "api/graphnet.deployment.i3modules.graphnet_module", "api/graphnet.models", "api/graphnet.models.coarsening", "api/graphnet.models.components", "api/graphnet.models.components.layers", "api/graphnet.models.components.pool", "api/graphnet.models.detector", "api/graphnet.models.detector.detector", "api/graphnet.models.detector.icecube", "api/graphnet.models.detector.prometheus", "api/graphnet.models.gnn", "api/graphnet.models.gnn.convnet", "api/graphnet.models.gnn.dynedge", "api/graphnet.models.gnn.dynedge_jinst", "api/graphnet.models.gnn.dynedge_kaggle_tito", "api/graphnet.models.gnn.gnn", "api/graphnet.models.graphs", "api/graphnet.models.graphs.edges", "api/graphnet.models.graphs.edges.edges", "api/graphnet.models.graphs.graph_definition", "api/graphnet.models.graphs.graphs", "api/graphnet.models.graphs.nodes", "api/graphnet.models.graphs.nodes.nodes", "api/graphnet.models.model", "api/graphnet.models.standard_model", "api/graphnet.models.task", "api/graphnet.models.task.classification", "api/graphnet.models.task.reconstruction", "api/graphnet.models.task.task", "api/graphnet.models.utils", "api/graphnet.pisa", "api/graphnet.pisa.fitting", "api/graphnet.pisa.plotting", "api/graphnet.training", "api/graphnet.training.callbacks", "api/graphnet.training.labels", "api/graphnet.training.loss_functions", "api/graphnet.training.utils", "api/graphnet.training.weight_fitting", "api/graphnet.utilities", "api/graphnet.utilities.argparse", "api/graphnet.utilities.config", "api/graphnet.utilities.config.base_config", "api/graphnet.utilities.config.configurable", "api/graphnet.utilities.config.dataset_config", "api/graphnet.utilities.config.model_config", "api/graphnet.utilities.config.parsing", "api/graphnet.utilities.config.training_config", "api/graphnet.utilities.decorators", "api/graphnet.utilities.filesys", "api/graphnet.utilities.imports", "api/graphnet.utilities.logging", "api/graphnet.utilities.maths", "api/modules", "contribute", "index", "install"], "filenames": ["about.md", "api/graphnet.rst", "api/graphnet.constants.rst", "api/graphnet.data.rst", "api/graphnet.data.constants.rst", "api/graphnet.data.dataconverter.rst", "api/graphnet.data.dataloader.rst", "api/graphnet.data.dataset.rst", "api/graphnet.data.dataset.dataset.rst", "api/graphnet.data.dataset.parquet.rst", "api/graphnet.data.dataset.parquet.parquet_dataset.rst", "api/graphnet.data.dataset.sqlite.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.rst", "api/graphnet.data.extractors.rst", "api/graphnet.data.extractors.i3extractor.rst", "api/graphnet.data.extractors.i3featureextractor.rst", "api/graphnet.data.extractors.i3genericextractor.rst", "api/graphnet.data.extractors.i3hybridrecoextractor.rst", "api/graphnet.data.extractors.i3ntmuonlabelsextractor.rst", "api/graphnet.data.extractors.i3particleextractor.rst", "api/graphnet.data.extractors.i3pisaextractor.rst", "api/graphnet.data.extractors.i3quesoextractor.rst", "api/graphnet.data.extractors.i3retroextractor.rst", "api/graphnet.data.extractors.i3splinempeextractor.rst", "api/graphnet.data.extractors.i3truthextractor.rst", "api/graphnet.data.extractors.i3tumextractor.rst", "api/graphnet.data.extractors.utilities.rst", "api/graphnet.data.extractors.utilities.collections.rst", "api/graphnet.data.extractors.utilities.frames.rst", "api/graphnet.data.extractors.utilities.types.rst", "api/graphnet.data.parquet.rst", "api/graphnet.data.parquet.parquet_dataconverter.rst", "api/graphnet.data.pipeline.rst", "api/graphnet.data.sqlite.rst", "api/graphnet.data.sqlite.sqlite_dataconverter.rst", "api/graphnet.data.sqlite.sqlite_utilities.rst", "api/graphnet.data.utilities.rst", "api/graphnet.data.utilities.parquet_to_sqlite.rst", "api/graphnet.data.utilities.random.rst", "api/graphnet.data.utilities.string_selection_resolver.rst", "api/graphnet.deployment.rst", "api/graphnet.deployment.i3modules.rst", "api/graphnet.deployment.i3modules.deployer.rst", "api/graphnet.deployment.i3modules.graphnet_module.rst", "api/graphnet.models.rst", "api/graphnet.models.coarsening.rst", "api/graphnet.models.components.rst", "api/graphnet.models.components.layers.rst", "api/graphnet.models.components.pool.rst", "api/graphnet.models.detector.rst", "api/graphnet.models.detector.detector.rst", "api/graphnet.models.detector.icecube.rst", "api/graphnet.models.detector.prometheus.rst", "api/graphnet.models.gnn.rst", "api/graphnet.models.gnn.convnet.rst", "api/graphnet.models.gnn.dynedge.rst", "api/graphnet.models.gnn.dynedge_jinst.rst", "api/graphnet.models.gnn.dynedge_kaggle_tito.rst", "api/graphnet.models.gnn.gnn.rst", "api/graphnet.models.graphs.rst", "api/graphnet.models.graphs.edges.rst", "api/graphnet.models.graphs.edges.edges.rst", "api/graphnet.models.graphs.graph_definition.rst", "api/graphnet.models.graphs.graphs.rst", "api/graphnet.models.graphs.nodes.rst", "api/graphnet.models.graphs.nodes.nodes.rst", "api/graphnet.models.model.rst", "api/graphnet.models.standard_model.rst", "api/graphnet.models.task.rst", "api/graphnet.models.task.classification.rst", "api/graphnet.models.task.reconstruction.rst", "api/graphnet.models.task.task.rst", "api/graphnet.models.utils.rst", "api/graphnet.pisa.rst", "api/graphnet.pisa.fitting.rst", "api/graphnet.pisa.plotting.rst", "api/graphnet.training.rst", "api/graphnet.training.callbacks.rst", "api/graphnet.training.labels.rst", "api/graphnet.training.loss_functions.rst", "api/graphnet.training.utils.rst", "api/graphnet.training.weight_fitting.rst", "api/graphnet.utilities.rst", "api/graphnet.utilities.argparse.rst", "api/graphnet.utilities.config.rst", "api/graphnet.utilities.config.base_config.rst", "api/graphnet.utilities.config.configurable.rst", "api/graphnet.utilities.config.dataset_config.rst", "api/graphnet.utilities.config.model_config.rst", "api/graphnet.utilities.config.parsing.rst", "api/graphnet.utilities.config.training_config.rst", "api/graphnet.utilities.decorators.rst", "api/graphnet.utilities.filesys.rst", "api/graphnet.utilities.imports.rst", "api/graphnet.utilities.logging.rst", "api/graphnet.utilities.maths.rst", "api/modules.rst", "contribute.md", "index.rst", "install.md"], "titles": ["About", "API", "constants", "data", "constants", "dataconverter", "dataloader", "dataset", "dataset", "parquet", "parquet_dataset", "sqlite", "sqlite_dataset", "sqlite_dataset_perturbed", "extractors", "i3extractor", "i3featureextractor", "i3genericextractor", "i3hybridrecoextractor", "i3ntmuonlabelsextractor", "i3particleextractor", "i3pisaextractor", "i3quesoextractor", "i3retroextractor", "i3splinempeextractor", "i3truthextractor", "i3tumextractor", "utilities", "collections", "frames", "types", "parquet", "parquet_dataconverter", "pipeline", "sqlite", "sqlite_dataconverter", "sqlite_utilities", "utilities", "parquet_to_sqlite", "random", "string_selection_resolver", "deployment", "i3modules", "deployer", "graphnet_module", "models", "coarsening", "components", "layers", "pool", "detector", "detector", "icecube", "prometheus", "gnn", "convnet", "dynedge", "dynedge_jinst", "dynedge_kaggle_tito", "gnn", "graphs", "edges", "edges", "graph_definition", "graphs", "nodes", "nodes", "model", "standard_model", "task", "classification", "reconstruction", "task", "utils", "pisa", "fitting", "plotting", "training", "callbacks", "labels", "loss_functions", "utils", "weight_fitting", "utilities", "argparse", "config", "base_config", "configurable", "dataset_config", "model_config", "parsing", "training_config", "decorators", "filesys", "imports", "logging", "maths", "src", "Contribute", "About", "Install"], "terms": {"graphnet": [0, 1, 2, 3, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35, 36, 37, 38, 39, 40, 41, 75, 76, 77, 82, 83, 84, 93, 94, 95, 98, 99, 100], "i": [0, 1, 15, 17, 28, 29, 30, 35, 36, 39, 40, 76, 82, 84, 93, 94, 95, 98, 99, 100], "an": [0, 5, 30, 32, 35, 40, 93, 95, 98, 99, 100], "open": [0, 98, 99], "sourc": [0, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95, 98, 99], "python": [0, 1, 5, 14, 15, 17, 28, 30, 98, 99, 100], "framework": [0, 99], "aim": [0, 1, 98, 99], "provid": [0, 1, 98, 99, 100], "high": [0, 99], "qualiti": [0, 99], "user": [0, 99, 100], "friendli": [0, 99], "end": [0, 1, 5, 32, 35, 99], "function": [0, 5, 30, 36, 39, 75, 76, 83, 93, 94, 99], "perform": [0, 99], "reconstruct": [0, 1, 16, 18, 19, 23, 24, 26, 41, 45, 69, 99], "task": [0, 1, 45, 98, 99], "neutrino": [0, 1, 75, 99], "telescop": [0, 1, 99], "us": [0, 1, 2, 4, 5, 15, 20, 25, 27, 28, 32, 35, 36, 37, 38, 40, 41, 75, 82, 83, 84, 94, 95, 98, 99, 100], "graph": [0, 1, 45, 98, 99], "neural": [0, 1, 99], "network": [0, 1, 99], "gnn": [0, 1, 45, 99, 100], "make": [0, 5, 82, 98, 99, 100], "fast": [0, 99, 100], "easi": [0, 99], "train": [0, 1, 40, 41, 82, 84, 97, 99, 100], "complex": [0, 99], "model": [0, 1, 41, 76, 77, 84, 97, 99, 100], "can": [0, 1, 15, 17, 20, 38, 75, 76, 82, 84, 98, 99, 100], "event": [0, 1, 22, 36, 38, 40, 75, 82, 99], "state": [0, 99], "art": [0, 99], "arbitrari": [0, 99], "detector": [0, 1, 25, 45, 99], "configur": [0, 1, 75, 83, 85, 95, 99], "infer": [0, 1, 41, 99, 100], "time": [0, 4, 36, 95, 99, 100], "ar": [0, 1, 4, 5, 17, 30, 32, 35, 38, 40, 75, 82, 98, 99, 100], "order": [0, 28, 99], "magnitud": [0, 99], "faster": [0, 99], "than": [0, 95, 99], "tradit": [0, 99], "techniqu": [0, 99], "common": [0, 1, 92, 94, 99], "ml": [0, 1, 99], "develop": [0, 1, 98, 99, 100], "physicist": [0, 1, 99], "wish": [0, 98, 99], "tool": [0, 1, 99], "research": [0, 99], "By": [0, 38, 99], "unit": [0, 5, 94, 98, 99], "both": [0, 17, 76, 99], "group": [0, 5, 32, 35, 99], "increas": [0, 99], "longev": [0, 99], "usabl": [0, 99], "individu": [0, 5, 99], "code": [0, 25, 36, 99], "contribut": [0, 99, 100], "from": [0, 1, 14, 15, 17, 19, 20, 22, 28, 29, 30, 35, 38, 76, 95, 98, 99, 100], "build": [0, 1, 99], "gener": [0, 5, 17, 99], "reusabl": [0, 99], "softwar": [0, 99], "packag": [0, 1, 39, 93, 94, 98, 99, 100], "base": [0, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 75, 82, 84, 94, 95, 99], "engin": [0, 99], "best": [0, 98, 99], "practic": [0, 98, 99], "lower": [0, 76, 99], "technic": [0, 99], "threshold": [0, 99], "most": [0, 1, 40, 99, 100], "scientif": [0, 1, 99], "problem": [0, 98, 99], "The": [0, 5, 28, 30, 35, 36, 75, 76, 99], "improv": [0, 1, 84, 99], "classif": [0, 1, 45, 69, 99], "yield": [0, 75, 99], "veri": [0, 40, 99], "accur": [0, 99], "e": [0, 1, 5, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 28, 30, 32, 35, 36, 40, 82, 95, 98, 99, 100], "g": [0, 1, 5, 25, 28, 30, 32, 35, 36, 40, 82, 95, 98, 99, 100], "low": [0, 99], "energi": [0, 4, 82, 99], "observ": [0, 99], "icecub": [0, 1, 16, 29, 30, 45, 50, 94, 99, 100], "here": [0, 98, 99, 100], "implement": [0, 1, 5, 15, 31, 32, 34, 35, 98, 99], "wa": [0, 99], "appli": [0, 15, 99], "oscil": [0, 74, 99], "lead": [0, 99], "signific": [0, 99], "angular": [0, 99], "rang": [0, 99], "relev": [0, 1, 30, 39, 93, 98, 99], "studi": [0, 99], "furthermor": [0, 99], "shown": [0, 99], "could": [0, 98, 99], "muon": [0, 19, 99], "v": [0, 99], "therebi": [0, 1, 99], "effici": [0, 99], "puriti": [0, 99], "sampl": [0, 40, 99], "analysi": [0, 99, 100], "similarli": [0, 30, 99], "ha": [0, 5, 30, 32, 35, 36, 93, 99, 100], "great": [0, 99], "point": [0, 24, 99], "analys": [0, 41, 74, 99], "final": [0, 99], "millisecond": [0, 99], "allow": [0, 41, 99, 100], "whole": [0, 99], "new": [0, 1, 35, 98, 99], "type": [0, 5, 14, 15, 27, 28, 29, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95, 98, 99], "cosmic": [0, 99], "alert": [0, 99], "which": [0, 15, 16, 25, 29, 40, 75, 84, 99, 100], "were": [0, 99], "previous": [0, 99], "unfeas": [0, 99], "possibl": [0, 28, 98, 99], "identifi": [0, 5, 25, 99], "10": [0, 84, 99], "tev": [0, 99], "monitor": [0, 99], "rate": [0, 99], "direct": [0, 99], "real": [0, 99], "thi": [0, 3, 5, 15, 17, 30, 32, 35, 36, 39, 75, 76, 82, 95, 98, 99, 100], "enabl": [0, 3, 99], "first": [0, 98, 99], "ever": [0, 99], "despit": [0, 99], "larg": [0, 99], "background": [0, 99], "origin": [0, 75, 99], "compris": [0, 99], "number": [0, 5, 32, 35, 40, 84, 99], "modul": [0, 3, 30, 41, 74, 77, 83, 94, 99], "necessari": [0, 28, 98, 99], "workflow": [0, 99], "ingest": [0, 1, 3, 99], "raw": [0, 99], "data": [0, 1, 4, 5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 84, 94, 97, 99, 100], "domain": [0, 1, 3, 41, 99], "specif": [0, 1, 3, 5, 16, 30, 31, 32, 34, 35, 36, 41, 98, 99, 100], "format": [0, 1, 3, 5, 28, 32, 35, 76, 98, 99, 100], "deploi": [0, 1, 41, 99], "chain": [0, 1, 41, 99, 100], "illustr": [0, 98, 99], "figur": [0, 76, 99], "level": [0, 25, 36, 95, 99, 100], "overview": [0, 99], "typic": [0, 28, 99], "convert": [0, 1, 3, 5, 28, 32, 35, 38, 99, 100], "industri": [0, 3, 99], "standard": [0, 3, 4, 5, 32, 35, 40, 84, 98, 99], "intermedi": [0, 1, 3, 5, 32, 35, 99, 100], "file": [0, 1, 3, 5, 15, 28, 32, 35, 38, 39, 75, 84, 93, 95, 99, 100], "read": [0, 3, 28, 99, 100], "simpl": [0, 99], "physic": [0, 1, 15, 29, 30, 41, 99], "orient": [0, 99], "compon": [0, 1, 45, 99], "manag": [0, 15, 77, 99], "experi": [0, 1, 77, 99], "log": [0, 1, 77, 83, 99, 100], "deploy": [0, 1, 42, 97, 99], "modular": [0, 99], "subclass": [0, 99], "torch": [0, 94, 99, 100], "nn": [0, 99], "mean": [0, 5, 32, 35, 99], "onli": [0, 1, 75, 82, 94, 99, 100], "need": [0, 28, 99, 100], "import": [0, 1, 36, 83, 99], "few": [0, 98, 99], "exist": [0, 35, 36, 99], "purpos": [0, 99], "built": [0, 99], "them": [0, 1, 28, 75, 99, 100], "togeth": [0, 99], "form": [0, 99], "complet": [0, 99], "extend": [0, 1, 99], "suit": [0, 99], "through": [0, 99], "layer": [0, 45, 47, 99], "connect": [0, 99], "etc": [0, 95, 99], "optimis": [0, 1, 99], "differ": [0, 15, 98, 99, 100], "track": [0, 15, 19, 98, 99], "These": [0, 98, 99], "prepar": [0, 99], "satisfi": [0, 99], "o": [0, 99], "load": [0, 39, 99], "requir": [0, 21, 36, 99, 100], "when": [0, 5, 28, 32, 35, 36, 95, 98, 99, 100], "batch": [0, 84, 99], "do": [0, 98, 99, 100], "predict": [0, 20, 24, 26, 99], "either": [0, 99, 100], "contain": [0, 5, 28, 29, 32, 35, 82, 84, 99, 100], "imag": [0, 1, 98, 99, 100], "portabl": [0, 99], "depend": [0, 99, 100], "free": [0, 99], "split": [0, 99], "up": [0, 5, 32, 35, 98, 99, 100], "interfac": [0, 74, 99, 100], "block": [0, 1, 99], "pre": [0, 98, 99], "directli": [0, 15, 99], "while": [0, 17, 99], "continu": [0, 99], "expand": [0, 99], "": [0, 5, 15, 28, 35, 38, 82, 84, 95, 99, 100], "capabl": [0, 99], "project": [0, 98, 99], "receiv": [0, 99], "fund": [0, 99], "european": [0, 99], "union": [0, 17, 28, 30, 93, 99], "horizon": [0, 99], "2020": [0, 99], "innov": [0, 99], "programm": [0, 99], "under": [0, 99], "mari": [0, 99], "sk\u0142odowska": [0, 99], "curi": [0, 99], "grant": [0, 99], "agreement": [0, 98, 99], "No": [0, 99], "890778": [0, 99], "work": [0, 4, 29, 98, 99, 100], "rasmu": [0, 99], "\u00f8rs\u00f8e": [0, 99], "partli": [0, 99], "punch4nfdi": [0, 99], "consortium": [0, 99], "support": [0, 30, 98, 99, 100], "dfg": [0, 99], "nfdi": [0, 99], "39": [0, 99, 100], "1": [0, 5, 28, 32, 35, 40, 82, 99, 100], "germani": [0, 99], "conveni": [1, 98, 100], "collabor": 1, "solv": [1, 98], "It": [1, 28, 36, 98], "leverag": 1, "advanc": 1, "machin": [1, 100], "learn": [1, 100], "without": [1, 75, 100], "have": [1, 5, 17, 32, 35, 36, 40, 98, 100], "expert": 1, "themselv": 1, "acceler": 1, "area": 1, "phyic": 1, "design": 1, "principl": 1, "all": [1, 5, 15, 17, 32, 35, 36, 95, 98, 100], "streamlin": 1, "process": [1, 5, 15, 98, 100], "transform": [1, 82], "extens": [1, 93], "basic": 1, "across": [1, 2, 30, 37, 83, 84, 95], "variou": 1, "easili": 1, "architectur": 1, "main": [1, 98, 100], "featur": [1, 3, 4, 5, 16, 98], "i3": [1, 5, 15, 29, 30, 32, 35, 39, 93, 100], "more": [1, 36, 39, 95], "index": [1, 5, 30, 36], "sqlite": [1, 3, 7, 35, 36, 38, 100], "suitabl": 1, "plug": 1, "plai": 1, "abstract": [1, 5], "awai": 1, "detail": [1, 100], "expos": 1, "physicst": 1, "what": [1, 98], "i3modul": [1, 41], "includ": [1, 75, 98], "docker": 1, "run": [1, 38], "containeris": 1, "fashion": 1, "subpackag": [1, 3, 7, 14, 41, 45, 60, 83], "dataset": [1, 3, 19, 40, 84], "extractor": [1, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35], "parquet": [1, 3, 7, 32, 38, 100], "util": [1, 3, 14, 28, 29, 30, 36, 38, 39, 40, 45, 77, 84, 93, 94, 95, 97], "constant": [1, 3, 97], "dataconvert": [1, 3, 32, 35], "dataload": [1, 3], "pipelin": [1, 3], "coarsen": [1, 45], "standard_model": [1, 45], "pisa": [1, 21, 75, 76, 94, 97, 100], "fit": [1, 74, 76, 82], "plot": [1, 74], "callback": [1, 77], "label": [1, 19, 22, 76, 77], "loss_funct": [1, 77], "weight_fit": [1, 77], "config": [1, 40, 75, 83, 84], "argpars": [1, 83], "decor": [1, 5, 83, 94], "filesi": [1, 83], "math": [1, 83], "submodul": [1, 3, 7, 9, 11, 14, 27, 31, 34, 37, 42, 45, 47, 50, 54, 60, 61, 65, 69, 74, 77, 83, 85], "global": [2, 4], "i3extractor": [3, 5, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35], "i3featureextractor": [3, 4, 14, 35], "i3genericextractor": [3, 14, 35], "i3hybridrecoextractor": [3, 14], "i3ntmuonlabelsextractor": [3, 14], "i3particleextractor": [3, 14], "i3pisaextractor": [3, 14], "i3quesoextractor": [3, 14], "i3retroextractor": [3, 14], "i3splinempeextractor": [3, 14], "i3truthextractor": [3, 4, 14], "i3tumextractor": [3, 14], "parquet_dataconvert": [3, 31], "sqlite_dataconvert": [3, 34], "sqlite_util": [3, 34], "parquet_to_sqlit": [3, 37], "random": [3, 37, 40], "string_selection_resolv": [3, 37], "truth": [3, 4, 16, 25, 36, 82], "fileset": [3, 5], "init_global_index": [3, 5], "cache_output_fil": [3, 5], "class": [4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 34, 35, 38, 40, 75, 82, 84, 95, 98], "object": [4, 5, 15, 17, 28, 30, 75, 84, 95], "namespac": 4, "name": [4, 5, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 30, 32, 35, 36, 38, 75, 82, 84, 95, 98, 100], "icecube86": 4, "dom_x": 4, "dom_i": 4, "dom_z": 4, "dom_tim": 4, "charg": 4, "rde": 4, "pmt_area": 4, "deepcor": [4, 16], "upgrad": [4, 16, 100], "string": [4, 5, 28, 32, 35, 40], "pmt_number": 4, "dom_numb": 4, "pmt_dir_x": 4, "pmt_dir_i": 4, "pmt_dir_z": 4, "dom_typ": 4, "prometheu": [4, 45, 50], "sensor_pos_x": 4, "sensor_pos_i": 4, "sensor_pos_z": 4, "t": [4, 30, 36, 76, 100], "kaggl": 4, "x": [4, 5, 25, 32, 35, 76, 82], "y": [4, 25, 76, 100], "z": [4, 5, 25, 32, 35, 100], "auxiliari": 4, "energy_track": 4, "position_x": 4, "position_i": 4, "position_z": 4, "azimuth": 4, "zenith": 4, "pid": [4, 40], "elast": 4, "sim_typ": 4, "interaction_typ": 4, "interaction_tim": 4, "inelast": 4, "stopped_muon": 4, "injection_energi": 4, "injection_typ": 4, "injection_interaction_typ": 4, "injection_zenith": 4, "injection_azimuth": 4, "injection_bjorkenx": 4, "injection_bjorkeni": 4, "injection_position_x": 4, "injection_position_i": 4, "injection_position_z": 4, "injection_column_depth": 4, "primary_lepton_1_typ": 4, "primary_hadron_1_typ": 4, "primary_lepton_1_position_x": 4, "primary_lepton_1_position_i": 4, "primary_lepton_1_position_z": 4, "primary_hadron_1_position_x": 4, "primary_hadron_1_position_i": 4, "primary_hadron_1_position_z": 4, "primary_lepton_1_direction_theta": 4, "primary_lepton_1_direction_phi": 4, "primary_hadron_1_direction_theta": 4, "primary_hadron_1_direction_phi": 4, "primary_lepton_1_energi": 4, "primary_hadron_1_energi": 4, "total_energi": 4, "i3_fil": [5, 15], "str": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 82, 84, 93, 95], "gcd_file": [5, 15], "paramet": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95], "output_fil": [5, 32, 35], "global_index": 5, "avail": [5, 17, 94], "pool": [5, 45, 47], "worker": [5, 32, 35, 39, 84, 95], "return": [5, 15, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95], "none": [5, 15, 17, 25, 29, 30, 32, 35, 36, 38, 40, 75, 82, 84, 93, 95], "synchron": 5, "list": [5, 15, 17, 25, 28, 30, 32, 35, 36, 38, 39, 40, 76, 82, 93, 95], "process_method": 5, "cach": 5, "output": [5, 32, 35, 38, 75, 82, 100], "typevar": 5, "f": 5, "bound": [5, 76], "callabl": [5, 30, 82, 94], "ani": [5, 28, 29, 30, 32, 35, 76, 82, 84, 95, 100], "outdir": [5, 32, 35, 38, 75], "gcd_rescu": [5, 32, 35, 93], "nb_files_to_batch": [5, 32, 35], "sequential_batch_pattern": [5, 32, 35], "input_file_batch_pattern": [5, 32, 35], "index_column": [5, 32, 35, 36, 40, 75, 82], "icetray_verbos": [5, 32, 35], "abc": [5, 15, 82], "logger": [5, 15, 38, 40, 82, 83, 95, 100], "construct": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 75, 82, 84, 95], "regular": [5, 30, 32, 35], "express": [5, 32, 35], "accord": [5, 32, 35], "match": [5, 32, 35, 82, 93], "certain": [5, 32, 35, 38, 75], "pattern": [5, 32, 35], "wildcard": [5, 32, 35], "same": [5, 30, 32, 35, 36, 95], "input": [5, 32, 35], "replac": [5, 32, 35], "period": [5, 32, 35], "special": [5, 17, 32, 35], "interpret": [5, 32, 35], "liter": [5, 32, 35], "charact": [5, 32, 35], "regex": [5, 32, 35], "For": [5, 30, 32, 35], "instanc": [5, 15, 25, 30, 32, 35, 75, 100], "A": [5, 32, 35, 75, 82, 100], "_": [5, 32, 35], "0": [5, 32, 35, 40, 75, 76], "9": [5, 32, 35], "5": [5, 32, 35, 40, 84, 100], "zst": [5, 32, 35], "find": [5, 32, 35, 93], "whose": [5, 32, 35], "one": [5, 32, 35, 36, 93, 98, 100], "capit": [5, 32, 35], "letter": [5, 32, 35], "follow": [5, 32, 35, 82, 98, 100], "underscor": [5, 32, 35], "five": [5, 32, 35], "upgrade_genie_step4_141020_a_000000": [5, 32, 35], "upgrade_genie_step4_141020_a_000001": [5, 32, 35], "upgrade_genie_step4_141020_a_000008": [5, 32, 35], "upgrade_genie_step4_141020_a_000009": [5, 32, 35], "would": [5, 32, 35, 98], "upgrade_genie_step4_141020_a_00000x": [5, 32, 35], "suffix": [5, 32, 35], "upgrade_genie_step4_141020_a_000010": [5, 32, 35], "separ": [5, 28, 32, 35, 100], "upgrade_genie_step4_141020_a_00001x": [5, 32, 35], "int": [5, 19, 22, 32, 35, 40, 75, 82, 84, 95], "properti": [5, 15, 20, 30, 95], "file_suffix": [5, 32, 35], "execut": [5, 36], "method": [5, 15, 27, 28, 29, 30, 32, 35, 82], "set": [5, 17, 98], "inherit": [5, 15, 30, 95], "path": [5, 36, 39, 75, 76, 84, 93, 100], "correspond": [5, 28, 30, 35, 39, 82, 93, 100], "gcd": [5, 15, 29, 39, 93], "save_data": [5, 32, 35], "save": [5, 15, 28, 32, 35, 36, 75, 82, 100], "ordereddict": [5, 32, 35], "extract": [5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 35, 38, 39], "merge_fil": [5, 32, 35], "input_fil": [5, 32, 35], "merg": [5, 32, 35, 100], "result": [5, 32, 35, 100], "option": [5, 25, 32, 35, 75, 76, 82, 83, 84, 93, 100], "default": [5, 17, 25, 28, 32, 35, 36, 38, 75, 76, 82, 84, 93], "current": [5, 32, 35, 40, 98, 100], "rais": [5, 17, 32], "notimplementederror": [5, 32], "If": [5, 17, 32, 35, 75, 82, 98, 100], "been": [5, 32, 98], "backend": [5, 32, 35], "question": 5, "get_map_funct": 5, "nb_file": 5, "map": [5, 16, 17, 35, 36], "pure": [5, 14, 15, 17, 30], "multiprocess": [5, 100], "tupl": [5, 29, 30, 75, 76, 84], "parquet_dataset": [7, 9], "sqlite_dataset": [7, 11], "sqlite_dataset_perturb": [7, 11], "collect": [14, 15, 27], "i3fram": [14, 15, 17, 29, 30], "frame": [14, 15, 17, 27, 30, 35], "i3extractorcollect": [14, 15], "i3featureextractoricecube86": [14, 16], "i3featureextractoricecubedeepcor": [14, 16], "i3featureextractoricecubeupgrad": [14, 16], "i3pulsenoisetruthflagicecubeupgrad": [14, 16], "i3galacticplanehybridrecoextractor": [14, 18], "i3ntmuonlabelextractor": [14, 19], "i3splinempeicextractor": [14, 24], "inform": [15, 17, 25, 76], "should": [15, 28, 40, 98, 100], "__call__": 15, "icetrai": [15, 29, 30, 94], "keep": 15, "proven": 15, "tabl": [15, 35, 36, 75, 82], "set_fil": 15, "store": [15, 36, 75], "refer": 15, "being": 15, "get": [15, 29, 100], "multipl": [15, 95], "treat": 15, "singl": 15, "pulsemap": [16, 35], "puls": [16, 17, 29, 30, 35, 36], "seri": [16, 17, 29, 30, 36], "86": 16, "nois": [16, 29], "flag": 16, "ad": [16, 75], "kei": [17, 28, 29, 30, 35, 36], "exclude_kei": 17, "dynam": 17, "pars": [17, 76, 83, 84, 85], "call": [17, 30, 35, 75, 82, 95], "tri": [17, 30], "automat": [17, 98], "cast": [17, 30], "done": [17, 95, 98], "recurs": [17, 30, 93], "each": [17, 28, 30, 36, 38, 39, 75, 76, 93], "look": [17, 100], "member": [17, 30, 95], "variabl": [17, 30, 82, 95], "signatur": [17, 30], "similar": [17, 30, 100], "dict": [17, 28, 30, 35, 75, 76, 84], "handl": [17, 84, 95], "hand": 17, "case": [17, 100], "per": [17, 36, 82], "mc": [17, 35, 36], "tree": [17, 35], "trigger": 17, "exclud": [17, 38, 100], "valueerror": 17, "hybrid": 18, "galatict": 18, "plane": 18, "tum": [19, 26], "dnn": [19, 26], "padding_valu": [19, 22], "northeren": 19, "i3particl": 20, "other": [20, 36, 98], "algorithm": 20, "comparison": 20, "quantiti": 21, "select": [22, 40, 82, 98], "queso": 22, "retro": 23, "splinemp": 24, "border": 25, "mctree": [25, 29], "ndarrai": [25, 82], "arrai": [25, 28], "boundari": 25, "volum": 25, "coordin": 25, "particl": [25, 36], "start": [25, 98, 100], "stop": [25, 84], "within": 25, "hard": 25, "i3mctre": 25, "valu": [25, 28, 35, 36, 76, 84], "flatten_nested_dictionari": [27, 28], "serialis": [27, 28], "transpose_list_of_dict": [27, 28], "frame_is_montecarlo": [27, 29], "frame_is_nois": [27, 29], "get_om_keys_and_pulseseri": [27, 29], "is_boost_enum": [27, 30], "is_boost_class": [27, 30], "is_icecube_class": [27, 30], "is_typ": [27, 30], "is_method": [27, 30], "break_cyclic_recurs": [27, 30], "get_member_vari": [27, 30], "cast_object_to_pure_python": [27, 30], "cast_pulse_series_to_pure_python": [27, 30], "manipul": 28, "obj": [28, 30], "parent_kei": 28, "flatten": 28, "nest": 28, "dictionari": [28, 29, 30, 35, 75, 76], "non": [28, 30, 35, 36], "exampl": [28, 40, 100], "d": [28, 98], "b": 28, "c": [28, 100], "2": [28, 75, 76, 100], "a__b": 28, "applic": 28, "combin": 28, "parent": 28, "__": [28, 30], "concaten": 28, "nester": 28, "json": 28, "therefor": 28, "we": [28, 30, 40, 98, 100], "element": [28, 30], "outer": 28, "abl": [28, 100], "de": 28, "transpos": 28, "check": [29, 30, 35, 36, 84, 93, 94, 98, 100], "whether": [29, 30, 35, 36, 93, 94], "mont": 29, "carlo": 29, "simul": 29, "bool": [29, 30, 35, 36, 40, 75, 82, 84, 93, 94, 95], "pulseseri": 29, "calibr": [29, 30], "indici": [29, 40], "gcd_dict": [29, 30], "p": [29, 35], "om": [29, 30], "dataclass": 29, "i3calibr": 29, "indicesfor": 29, "boost": 30, "enum": 30, "fn": 30, "ensur": [30, 39, 95, 98, 100], "isn": 30, "return_discard": 30, "valid": [30, 40, 84], "ignor": 30, "mangl": 30, "take": [30, 35, 98], "mainli": 30, "cannot": 30, "trivial": 30, "doe": 30, "try": 30, "length": 30, "equival": 30, "its": 30, "like": [30, 98], "otherwis": 30, "itself": 30, "deem": 30, "wai": [30, 40, 98, 100], "represent": 30, "optic": 30, "found": 30, "parquetdataconvert": [31, 32], "sqlitedataconvert": [34, 35, 100], "construct_datafram": [34, 35], "is_pulse_map": [34, 35], "is_mc_tre": [34, 35], "database_exist": [34, 36], "database_table_exist": [34, 36], "run_sql_cod": [34, 36], "save_to_sql": [34, 36], "attach_index": [34, 36], "create_t": [34, 36], "create_table_and_save_to_sql": [34, 36], "db": 35, "databas": [35, 36, 38, 75, 82, 100], "max_table_s": 35, "maximum": [35, 84], "row": [35, 36], "given": [35, 82, 84], "exce": 35, "limit": 35, "creat": [35, 36, 98, 100], "any_pulsemap_is_non_empti": 35, "data_dict": 35, "empti": 35, "retriev": 35, "splitinicepuls": 35, "least": [35, 98, 100], "true": [35, 36, 75, 82], "becaus": [35, 39], "instead": 35, "alwai": 35, "panda": [35, 40, 82], "datafram": [35, 36, 40, 75, 82], "table_nam": [35, 36], "database_path": [36, 75, 82], "df": 36, "must": [36, 82, 98], "alreadi": [36, 100], "attach": 36, "queri": [36, 40], "column": [36, 75, 82], "default_typ": 36, "null": 36, "integer_primary_kei": 36, "event_no": [36, 40, 82], "NOT": 36, "integ": 36, "primari": 36, "Such": 36, "uniqu": [36, 38], "appropri": 36, "expect": [36, 40], "doesn": 36, "parquettosqliteconvert": [37, 38], "pairwise_shuffl": [37, 39], "stringselectionresolv": [37, 40], "parquet_path": 38, "mc_truth_tabl": 38, "excluded_field": 38, "assign": [38, 98], "id": 38, "everi": [38, 100], "field": [38, 76], "One": [38, 76], "choos": 38, "argument": [38, 82, 84], "exclude_field": 38, "database_nam": 38, "convers": [38, 100], "directori": [38, 75, 93], "rng": 39, "relat": [39, 93], "i3_list": [39, 93], "gcd_list": [39, 93], "shuffl": 39, "correpond": 39, "handi": 39, "even": 39, "files_list": 39, "gcd_shuffl": 39, "i3_shuffl": 39, "resolv": 40, "indic": [40, 84, 98], "seed": 40, "use_cach": 40, "datasetconfig": 40, "flexibl": 40, "defin": 40, "below": [40, 76, 82, 98, 100], "show": 40, "involv": 40, "cover": 40, "yml": [40, 84], "test": [40, 94, 98], "50000": 40, "ab": 40, "12": 40, "14": 40, "16": 40, "13": [40, 100], "10000": 40, "compat": 40, "syntax": 40, "mai": [40, 100], "also": 40, "specifi": [40, 76, 100], "fix": 40, "randomli": 40, "20": [40, 95], "graphnet_modul": [41, 42], "convnet": [45, 54], "dynedg": [45, 54], "dynedge_jinst": [45, 54], "dynedge_kaggle_tito": [45, 54], "edg": [45, 60], "node": [45, 60], "graph_definit": [45, 60], "config_updat": [74, 75], "weightfitt": [74, 75, 77, 82], "contourfitt": [74, 75], "read_entri": [74, 76], "plot_2d_contour": [74, 76], "plot_1d_contour": [74, 76], "contour": [75, 76], "config_path": 75, "new_config_path": 75, "dummy_sect": 75, "updat": 75, "temp": 75, "dummi": 75, "section": 75, "header": 75, "configupdat": 75, "programat": 75, "truth_tabl": [75, 82], "statistical_fit": 75, "weight": [75, 82, 100], "fit_weight": [75, 82], "config_outdir": 75, "weight_nam": [75, 82], "pisa_config_dict": 75, "add_to_databas": [75, 82], "flux": 75, "self": 75, "_database_path": 75, "statist": 75, "effect": [75, 98], "account": 75, "systemat": 75, "hypersurfac": 75, "chang": [75, 98], "assumpt": 75, "regard": 75, "fals": [75, 82], "two": 75, "pipeline_path": 75, "post_fix": 75, "model_nam": 75, "include_retro": 75, "fit_1d_contour": 75, "run_nam": 75, "config_dict": 75, "grid_siz": 75, "n_worker": 75, "theta23_minmax": 75, "36": 75, "54": 75, "dm31_minmax": 75, "3": [75, 76, 98, 100], "7": 75, "1d": [75, 76], "float": [75, 76], "fit_2d_contour": 75, "2d": [75, 76], "entri": [76, 84], "content": 76, "contour_data": 76, "xlim": 76, "4": 76, "6": 76, "ylim": 76, "0023799999999999997": 76, "0025499999999999997": 76, "chi2_critical_valu": 76, "width": 76, "height": 76, "path_to_pisa_fit_result": 76, "name_of_my_model_in_fit": 76, "legend": 76, "color": 76, "linestyl": 76, "style": [76, 98], "line": [76, 84], "upper": 76, "axi": 76, "605": 76, "critic": [76, 95], "chi2": 76, "90": 76, "cl": 76, "note": 76, "right": 76, "176": 76, "inch": 76, "388": 76, "706": 76, "abov": [76, 82, 100], "352": 76, "uniform": [77, 82], "bjoernlow": [77, 82], "produc": 82, "public": 82, "uniformweightfitt": 82, "bin": 82, "kwarg": [82, 95], "privat": 82, "_fit_weight": 82, "sql": 82, "desir": [82, 93], "space": 82, "np": 82, "log10": 82, "happen": 82, "addit": 82, "pass": [82, 98], "distribut": 82, "x_low": 82, "wherea": 82, "curv": 82, "base_config": [83, 85], "dataset_config": [83, 85], "model_config": [83, 85], "training_config": [83, 85], "argumentpars": [83, 84], "is_gcd_fil": [83, 93], "is_i3_fil": [83, 93], "has_extens": [83, 93], "find_i3_fil": [83, 93], "has_icecube_packag": [83, 94], "has_torch_packag": [83, 94], "has_pisa_packag": [83, 94], "requires_icecub": [83, 94], "repeatfilt": [83, 95], "consist": [84, 95, 98], "cli": 84, "present": [84, 93, 94], "pop_default": 84, "remov": 84, "usag": 84, "descript": 84, "command": [84, 100], "standard_argu": 84, "size": 84, "128": 84, "help": [84, 98], "home": [84, 100], "runner": 84, "local": 84, "lib": [84, 100], "python3": 84, "training_example_data_sqlit": 84, "earli": 84, "patienc": 84, "epoch": 84, "loss": 84, "after": 84, "gpu": [84, 100], "narg": 84, "max": 84, "50": 84, "example_energy_reconstruction_model": 84, "num": 84, "fetch": 84, "with_standard_argu": 84, "arg": [84, 95], "add": [84, 98, 100], "overwritten": 84, "system": [93, 100], "filenam": 93, "dir": 93, "search": 93, "test_funct": 94, "filter": 95, "out": [95, 98, 100], "repeat": 95, "messag": 95, "nb_repeats_allow": 95, "record": 95, "print": 95, "logrecord": 95, "class_nam": 95, "log_fold": 95, "clear": 95, "intuit": 95, "composit": 95, "rather": 95, "loggeradapt": 95, "chosen": 95, "avoid": [95, 98], "clash": 95, "pytorch_lightn": 95, "lightningmodul": 95, "setlevel": 95, "deleg": 95, "msg": 95, "error": [95, 98], "warn": 95, "info": [95, 100], "debug": 95, "warning_onc": 95, "exactli": 95, "onc": 95, "handler": 95, "file_handl": 95, "filehandl": 95, "stream_handl": 95, "streamhandl": 95, "api": 97, "To": [98, 100], "sure": [98, 100], "smooth": 98, "guidelin": 98, "guid": 98, "encourag": 98, "contributor": 98, "discuss": 98, "bug": 98, "anyth": 98, "you": [98, 100], "place": 98, "describ": 98, "altern": 98, "yourself": 98, "ownership": 98, "particular": 98, "activ": [98, 100], "transpar": 98, "prioriti": 98, "situat": 98, "lot": 98, "effort": 98, "go": 98, "turn": 98, "outsid": 98, "scope": 98, "solut": 98, "better": 98, "fork": 98, "repo": 98, "dedic": 98, "branch": [98, 100], "your": [98, 100], "repositori": 98, "graphdefinit": 98, "euclidean": 98, "definit": 98, "own": [98, 100], "team": 98, "accept": 98, "autom": 98, "review": 98, "pep8": 98, "docstr": 98, "googl": 98, "hint": 98, "clean": [98, 100], "see": [98, 100], "version": [98, 100], "8": [98, 100], "adher": 98, "pep": 98, "pylint": 98, "flake8": 98, "black": 98, "well": 98, "recommend": [98, 100], "mypi": 98, "pydocstyl": 98, "docformatt": 98, "commit": 98, "hook": 98, "instal": 98, "come": 98, "tag": [98, 100], "pip": [98, 100], "Then": 98, "everytim": 98, "pep257": 98, "static": 98, "concept": 98, "http": 98, "ljvmiranda921": 98, "io": 98, "notebook": 98, "2018": 98, "06": 98, "21": 98, "precommit": 98, "environ": 100, "virtual": 100, "anaconda": 100, "prove": 100, "instruct": 100, "setup": 100, "want": 100, "part": 100, "In": 100, "runtim": 100, "achiev": 100, "bash": 100, "shell": 100, "eval": 100, "cvmf": 100, "opensciencegrid": 100, "org": 100, "py3": 100, "v4": 100, "sh": 100, "rhel_7_x86_64": 100, "metaproject": 100, "v1": 100, "env": 100, "alia": 100, "script": 100, "With": 100, "now": 100, "light": 100, "extra": 100, "pytorch": 100, "geometr": 100, "just": 100, "won": 100, "later": 100, "don": 100, "r": 100, "torch_cpu": 100, "txt": 100, "cpu": 100, "torch_gpu": 100, "prefer": 100, "unix": 100, "git": 100, "clone": 100, "github": 100, "com": 100, "usernam": 100, "cd": 100, "conda": 100, "gcc_linux": 100, "64": 100, "gxx_linux": 100, "libgcc": 100, "cudatoolkit": 100, "11": 100, "forg": 100, "torch_maco": 100, "On": 100, "maco": 100, "box": 100, "compil": 100, "gcc": 100, "date": 100, "possibli": 100, "cuda": 100, "toolkit": 100, "recent": 100, "omit": 100, "newer": 100, "export": 100, "ld_library_path": 100, "anaconda3": 100, "miniconda3": 100, "bashrc": 100, "librari": 100, "access": 100, "so": 100, "re": 100, "intend": 100, "consid": 100, "rm": 100, "asogaard": 100, "latest": 100, "dc423315742c": 100, "01_icetrai": 100, "01_convert_i3_fil": 100, "py": 100, "2023": 100, "01": 100, "24": 100, "41": 100, "27": 100, "__init__": 100, "write": 100, "graphnet_20230124": 100, "134127": 100, "46": 100, "root": 100, "convert_i3_fil": 100, "ic86": 100, "thread": 100, "100": 100, "00": 100, "79": 100, "42": 100, "26": 100, "413": 100, "88it": 100, "specialis": 100, "ones": 100, "push": 100, "vx": 100}, "objects": {"": [[1, 0, 0, "-", "graphnet"]], "graphnet": [[2, 0, 0, "-", "constants"], [3, 0, 0, "-", "data"], [41, 0, 0, "-", "deployment"], [74, 0, 0, "-", "pisa"], [77, 0, 0, "-", "training"], [83, 0, 0, "-", "utilities"]], "graphnet.data": [[4, 0, 0, "-", "constants"], [5, 0, 0, "-", "dataconverter"], [14, 0, 0, "-", "extractors"], [31, 0, 0, "-", "parquet"], [34, 0, 0, "-", "sqlite"], [37, 0, 0, "-", "utilities"]], "graphnet.data.constants": [[4, 1, 1, "", "FEATURES"], [4, 1, 1, "", "TRUTH"]], "graphnet.data.constants.FEATURES": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.constants.TRUTH": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.dataconverter": [[5, 1, 1, "", "DataConverter"], [5, 1, 1, "", "FileSet"], [5, 5, 1, "", "cache_output_files"], [5, 5, 1, "", "init_global_index"]], "graphnet.data.dataconverter.DataConverter": [[5, 3, 1, "", "execute"], [5, 4, 1, "", "file_suffix"], [5, 3, 1, "", "get_map_function"], [5, 3, 1, "", "merge_files"], [5, 3, 1, "", "save_data"]], "graphnet.data.dataconverter.FileSet": [[5, 2, 1, "", "gcd_file"], [5, 2, 1, "", "i3_file"]], "graphnet.data.extractors": [[15, 0, 0, "-", "i3extractor"], [16, 0, 0, "-", "i3featureextractor"], [17, 0, 0, "-", "i3genericextractor"], [18, 0, 0, "-", "i3hybridrecoextractor"], [19, 0, 0, "-", "i3ntmuonlabelsextractor"], [20, 0, 0, "-", "i3particleextractor"], [21, 0, 0, "-", "i3pisaextractor"], [22, 0, 0, "-", "i3quesoextractor"], [23, 0, 0, "-", "i3retroextractor"], [24, 0, 0, "-", "i3splinempeextractor"], [25, 0, 0, "-", "i3truthextractor"], [26, 0, 0, "-", "i3tumextractor"], [27, 0, 0, "-", "utilities"]], "graphnet.data.extractors.i3extractor": [[15, 1, 1, "", "I3Extractor"], [15, 1, 1, "", "I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor.I3Extractor": [[15, 4, 1, "", "name"], [15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3extractor.I3ExtractorCollection": [[15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3featureextractor": [[16, 1, 1, "", "I3FeatureExtractor"], [16, 1, 1, "", "I3FeatureExtractorIceCube86"], [16, 1, 1, "", "I3FeatureExtractorIceCubeDeepCore"], [16, 1, 1, "", "I3FeatureExtractorIceCubeUpgrade"], [16, 1, 1, "", "I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3genericextractor": [[17, 1, 1, "", "I3GenericExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, 1, 1, "", "I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, 1, 1, "", "I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, 1, 1, "", "I3ParticleExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, 1, 1, "", "I3PISAExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, 1, 1, "", "I3QUESOExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, 1, 1, "", "I3RetroExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, 1, 1, "", "I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, 1, 1, "", "I3TruthExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, 1, 1, "", "I3TUMExtractor"]], "graphnet.data.extractors.utilities": [[28, 0, 0, "-", "collections"], [29, 0, 0, "-", "frames"], [30, 0, 0, "-", "types"]], "graphnet.data.extractors.utilities.collections": [[28, 5, 1, "", "flatten_nested_dictionary"], [28, 5, 1, "", "serialise"], [28, 5, 1, "", "transpose_list_of_dicts"]], "graphnet.data.extractors.utilities.frames": [[29, 5, 1, "", "frame_is_montecarlo"], [29, 5, 1, "", "frame_is_noise"], [29, 5, 1, "", "get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.types": [[30, 5, 1, "", "break_cyclic_recursion"], [30, 5, 1, "", "cast_object_to_pure_python"], [30, 5, 1, "", "cast_pulse_series_to_pure_python"], [30, 5, 1, "", "get_member_variables"], [30, 5, 1, "", "is_boost_class"], [30, 5, 1, "", "is_boost_enum"], [30, 5, 1, "", "is_icecube_class"], [30, 5, 1, "", "is_method"], [30, 5, 1, "", "is_type"]], "graphnet.data.parquet": [[32, 0, 0, "-", "parquet_dataconverter"]], "graphnet.data.parquet.parquet_dataconverter": [[32, 1, 1, "", "ParquetDataConverter"]], "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter": [[32, 2, 1, "", "file_suffix"], [32, 3, 1, "", "merge_files"], [32, 3, 1, "", "save_data"]], "graphnet.data.sqlite": [[35, 0, 0, "-", "sqlite_dataconverter"], [36, 0, 0, "-", "sqlite_utilities"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, 1, 1, "", "SQLiteDataConverter"], [35, 5, 1, "", "construct_dataframe"], [35, 5, 1, "", "is_mc_tree"], [35, 5, 1, "", "is_pulse_map"]], "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter": [[35, 3, 1, "", "any_pulsemap_is_non_empty"], [35, 2, 1, "", "file_suffix"], [35, 3, 1, "", "merge_files"], [35, 3, 1, "", "save_data"]], "graphnet.data.sqlite.sqlite_utilities": [[36, 5, 1, "", "attach_index"], [36, 5, 1, "", "create_table"], [36, 5, 1, "", "create_table_and_save_to_sql"], [36, 5, 1, "", "database_exists"], [36, 5, 1, "", "database_table_exists"], [36, 5, 1, "", "run_sql_code"], [36, 5, 1, "", "save_to_sql"]], "graphnet.data.utilities": [[38, 0, 0, "-", "parquet_to_sqlite"], [39, 0, 0, "-", "random"], [40, 0, 0, "-", "string_selection_resolver"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, 1, 1, "", "ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter": [[38, 3, 1, "", "run"]], "graphnet.data.utilities.random": [[39, 5, 1, "", "pairwise_shuffle"]], "graphnet.data.utilities.string_selection_resolver": [[40, 1, 1, "", "StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver": [[40, 3, 1, "", "resolve"]], "graphnet.pisa": [[75, 0, 0, "-", "fitting"], [76, 0, 0, "-", "plotting"]], "graphnet.pisa.fitting": [[75, 1, 1, "", "ContourFitter"], [75, 1, 1, "", "WeightFitter"], [75, 5, 1, "", "config_updater"]], "graphnet.pisa.fitting.ContourFitter": [[75, 3, 1, "", "fit_1d_contour"], [75, 3, 1, "", "fit_2d_contour"]], "graphnet.pisa.fitting.WeightFitter": [[75, 3, 1, "", "fit_weights"]], "graphnet.pisa.plotting": [[76, 5, 1, "", "plot_1D_contour"], [76, 5, 1, "", "plot_2D_contour"], [76, 5, 1, "", "read_entry"]], "graphnet.training": [[82, 0, 0, "-", "weight_fitting"]], "graphnet.training.weight_fitting": [[82, 1, 1, "", "BjoernLow"], [82, 1, 1, "", "Uniform"], [82, 1, 1, "", "WeightFitter"]], "graphnet.training.weight_fitting.WeightFitter": [[82, 3, 1, "", "fit"]], "graphnet.utilities": [[84, 0, 0, "-", "argparse"], [92, 0, 0, "-", "decorators"], [93, 0, 0, "-", "filesys"], [94, 0, 0, "-", "imports"], [95, 0, 0, "-", "logging"]], "graphnet.utilities.argparse": [[84, 1, 1, "", "ArgumentParser"], [84, 1, 1, "", "Options"]], "graphnet.utilities.argparse.ArgumentParser": [[84, 2, 1, "", "standard_arguments"], [84, 3, 1, "", "with_standard_arguments"]], "graphnet.utilities.argparse.Options": [[84, 3, 1, "", "contains"], [84, 3, 1, "", "pop_default"]], "graphnet.utilities.filesys": [[93, 5, 1, "", "find_i3_files"], [93, 5, 1, "", "has_extension"], [93, 5, 1, "", "is_gcd_file"], [93, 5, 1, "", "is_i3_file"]], "graphnet.utilities.imports": [[94, 5, 1, "", "has_icecube_package"], [94, 5, 1, "", "has_pisa_package"], [94, 5, 1, "", "has_torch_package"], [94, 5, 1, "", "requires_icecube"]], "graphnet.utilities.logging": [[95, 1, 1, "", "Logger"], [95, 1, 1, "", "RepeatFilter"]], "graphnet.utilities.logging.Logger": [[95, 3, 1, "", "critical"], [95, 3, 1, "", "debug"], [95, 3, 1, "", "error"], [95, 4, 1, "", "file_handlers"], [95, 4, 1, "", "handlers"], [95, 3, 1, "", "info"], [95, 3, 1, "", "setLevel"], [95, 4, 1, "", "stream_handlers"], [95, 3, 1, "", "warning"], [95, 3, 1, "", "warning_once"]], "graphnet.utilities.logging.RepeatFilter": [[95, 3, 1, "", "filter"], [95, 2, 1, "", "nb_repeats_allowed"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:attribute", "3": "py:method", "4": "py:property", "5": "py:function"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "attribute", "Python attribute"], "3": ["py", "method", "Python method"], "4": ["py", "property", "Python property"], "5": ["py", "function", "Python function"]}, "titleterms": {"about": [0, 99], "impact": [0, 99], "usag": [0, 99], "acknowledg": [0, 99], "api": 1, "constant": [2, 4], "data": 3, "dataconvert": 5, "dataload": 6, "dataset": [7, 8], "parquet": [9, 31], "parquet_dataset": 10, "sqlite": [11, 34], "sqlite_dataset": 12, "sqlite_dataset_perturb": 13, "extractor": 14, "i3extractor": 15, "i3featureextractor": 16, "i3genericextractor": 17, "i3hybridrecoextractor": 18, "i3ntmuonlabelsextractor": 19, "i3particleextractor": 20, "i3pisaextractor": 21, "i3quesoextractor": 22, "i3retroextractor": 23, "i3splinempeextractor": 24, "i3truthextractor": 25, "i3tumextractor": 26, "util": [27, 37, 73, 81, 83], "collect": 28, "frame": 29, "type": 30, "parquet_dataconvert": 32, "pipelin": 33, "sqlite_dataconvert": 35, "sqlite_util": 36, "parquet_to_sqlit": 38, "random": 39, "string_selection_resolv": 40, "deploy": [41, 43], "i3modul": 42, "graphnet_modul": 44, "model": [45, 67], "coarsen": 46, "compon": 47, "layer": 48, "pool": 49, "detector": [50, 51], "icecub": 52, "prometheu": 53, "gnn": [54, 59], "convnet": 55, "dynedg": 56, "dynedge_jinst": 57, "dynedge_kaggle_tito": 58, "graph": [60, 64], "edg": [61, 62], "graph_definit": 63, "node": [65, 66], "standard_model": 68, "task": [69, 72], "classif": 70, "reconstruct": 71, "pisa": 74, "fit": 75, "plot": 76, "train": 77, "callback": 78, "label": 79, "loss_funct": 80, "weight_fit": 82, "argpars": 84, "config": 85, "base_config": 86, "configur": 87, "dataset_config": 88, "model_config": 89, "pars": 90, "training_config": 91, "decor": 92, "filesi": 93, "import": 94, "log": 95, "math": 96, "src": 97, "contribut": 98, "github": 98, "issu": 98, "pull": 98, "request": 98, "convent": 98, "code": 98, "qualiti": 98, "instal": 100, "icetrai": 100, "stand": 100, "alon": 100, "run": 100, "docker": 100}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 60}, "alltitles": {"About": [[0, "about"], [99, "about"]], "Impact": [[0, "impact"], [99, "impact"]], "Usage": [[0, "usage"], [99, "usage"]], "Acknowledgements": [[0, "acknowledgements"], [99, "acknowledgements"]], "API": [[1, "module-graphnet"]], "constants": [[2, "module-graphnet.constants"], [4, "module-graphnet.data.constants"]], "data": [[3, "module-graphnet.data"]], "dataconverter": [[5, "module-graphnet.data.dataconverter"]], "dataloader": [[6, "dataloader"]], "dataset": [[7, "dataset"], [8, "dataset"]], "parquet": [[9, "parquet"], [31, "module-graphnet.data.parquet"]], "parquet_dataset": [[10, "parquet-dataset"]], "sqlite": [[11, "sqlite"], [34, "module-graphnet.data.sqlite"]], "sqlite_dataset": [[12, "sqlite-dataset"]], "sqlite_dataset_perturbed": [[13, "sqlite-dataset-perturbed"]], "extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "utilities": [[27, "module-graphnet.data.extractors.utilities"], [37, "module-graphnet.data.utilities"], [83, "module-graphnet.utilities"]], "collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "types": [[30, "module-graphnet.data.extractors.utilities.types"]], "parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "pipeline": [[33, "pipeline"]], "sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "random": [[39, "module-graphnet.data.utilities.random"]], "string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "deployment": [[41, "module-graphnet.deployment"]], "i3modules": [[42, "i3modules"]], "deployer": [[43, "deployer"]], "graphnet_module": [[44, "graphnet-module"]], "models": [[45, "models"]], "coarsening": [[46, "coarsening"]], "components": [[47, "components"]], "layers": [[48, "layers"]], "pool": [[49, "pool"]], "detector": [[50, "detector"], [51, "detector"]], "icecube": [[52, "icecube"]], "prometheus": [[53, "prometheus"]], "gnn": [[54, "gnn"], [59, "gnn"]], "convnet": [[55, "convnet"]], "dynedge": [[56, "dynedge"]], "dynedge_jinst": [[57, "dynedge-jinst"]], "dynedge_kaggle_tito": [[58, "dynedge-kaggle-tito"]], "graphs": [[60, "graphs"], [64, "graphs"]], "edges": [[61, "edges"], [62, "edges"]], "graph_definition": [[63, "graph-definition"]], "nodes": [[65, "nodes"], [66, "nodes"]], "model": [[67, "model"]], "standard_model": [[68, "standard-model"]], "task": [[69, "task"], [72, "task"]], "classification": [[70, "classification"]], "reconstruction": [[71, "reconstruction"]], "utils": [[73, "utils"], [81, "utils"]], "pisa": [[74, "module-graphnet.pisa"]], "fitting": [[75, "module-graphnet.pisa.fitting"]], "plotting": [[76, "module-graphnet.pisa.plotting"]], "training": [[77, "module-graphnet.training"]], "callbacks": [[78, "callbacks"]], "labels": [[79, "labels"]], "loss_functions": [[80, "loss-functions"]], "weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "argparse": [[84, "module-graphnet.utilities.argparse"]], "config": [[85, "config"]], "base_config": [[86, "base-config"]], "configurable": [[87, "configurable"]], "dataset_config": [[88, "dataset-config"]], "model_config": [[89, "model-config"]], "parsing": [[90, "parsing"]], "training_config": [[91, "training-config"]], "decorators": [[92, "module-graphnet.utilities.decorators"]], "filesys": [[93, "module-graphnet.utilities.filesys"]], "imports": [[94, "module-graphnet.utilities.imports"]], "logging": [[95, "module-graphnet.utilities.logging"]], "maths": [[96, "maths"]], "src": [[97, "src"]], "Contribute": [[98, "contribute"]], "GitHub issues": [[98, "github-issues"]], "Pull requests": [[98, "pull-requests"]], "Conventions": [[98, "conventions"]], "Code quality": [[98, "code-quality"]], "Install": [[100, "install"]], "Installing with IceTray": [[100, "installing-with-icetray"]], "Installing stand-alone": [[100, "installing-stand-alone"]], "Running in Docker": [[100, "running-in-docker"]]}, "indexentries": {"graphnet": [[1, "module-graphnet"]], "module": [[1, "module-graphnet"], [2, "module-graphnet.constants"], [3, "module-graphnet.data"], [4, "module-graphnet.data.constants"], [5, "module-graphnet.data.dataconverter"], [14, "module-graphnet.data.extractors"], [15, "module-graphnet.data.extractors.i3extractor"], [16, "module-graphnet.data.extractors.i3featureextractor"], [17, "module-graphnet.data.extractors.i3genericextractor"], [18, "module-graphnet.data.extractors.i3hybridrecoextractor"], [19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"], [20, "module-graphnet.data.extractors.i3particleextractor"], [21, "module-graphnet.data.extractors.i3pisaextractor"], [22, "module-graphnet.data.extractors.i3quesoextractor"], [23, "module-graphnet.data.extractors.i3retroextractor"], [24, "module-graphnet.data.extractors.i3splinempeextractor"], [25, "module-graphnet.data.extractors.i3truthextractor"], [26, "module-graphnet.data.extractors.i3tumextractor"], [27, "module-graphnet.data.extractors.utilities"], [28, "module-graphnet.data.extractors.utilities.collections"], [29, "module-graphnet.data.extractors.utilities.frames"], [30, "module-graphnet.data.extractors.utilities.types"], [31, "module-graphnet.data.parquet"], [32, "module-graphnet.data.parquet.parquet_dataconverter"], [34, "module-graphnet.data.sqlite"], [35, "module-graphnet.data.sqlite.sqlite_dataconverter"], [36, "module-graphnet.data.sqlite.sqlite_utilities"], [37, "module-graphnet.data.utilities"], [38, "module-graphnet.data.utilities.parquet_to_sqlite"], [39, "module-graphnet.data.utilities.random"], [40, "module-graphnet.data.utilities.string_selection_resolver"], [41, "module-graphnet.deployment"], [74, "module-graphnet.pisa"], [75, "module-graphnet.pisa.fitting"], [76, "module-graphnet.pisa.plotting"], [77, "module-graphnet.training"], [82, "module-graphnet.training.weight_fitting"], [83, "module-graphnet.utilities"], [84, "module-graphnet.utilities.argparse"], [92, "module-graphnet.utilities.decorators"], [93, "module-graphnet.utilities.filesys"], [94, "module-graphnet.utilities.imports"], [95, "module-graphnet.utilities.logging"]], "graphnet.constants": [[2, "module-graphnet.constants"]], "graphnet.data": [[3, "module-graphnet.data"]], "deepcore (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.DEEPCORE"]], "deepcore (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.DEEPCORE"]], "features (class in graphnet.data.constants)": [[4, "graphnet.data.constants.FEATURES"]], "icecube86 (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.ICECUBE86"]], "icecube86 (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.ICECUBE86"]], "kaggle (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.KAGGLE"]], "kaggle (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.KAGGLE"]], "prometheus (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.PROMETHEUS"]], "prometheus (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.PROMETHEUS"]], "truth (class in graphnet.data.constants)": [[4, "graphnet.data.constants.TRUTH"]], "upgrade (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.UPGRADE"]], "upgrade (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.UPGRADE"]], "graphnet.data.constants": [[4, "module-graphnet.data.constants"]], "dataconverter (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.DataConverter"]], "fileset (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.FileSet"]], "cache_output_files() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.cache_output_files"]], "execute() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.execute"]], "file_suffix (graphnet.data.dataconverter.dataconverter property)": [[5, "graphnet.data.dataconverter.DataConverter.file_suffix"]], "gcd_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.gcd_file"]], "get_map_function() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.get_map_function"]], "graphnet.data.dataconverter": [[5, "module-graphnet.data.dataconverter"]], "i3_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.i3_file"]], "init_global_index() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.init_global_index"]], "merge_files() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.merge_files"]], "save_data() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.save_data"]], "graphnet.data.extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor"]], "i3extractorcollection (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "name (graphnet.data.extractors.i3extractor.i3extractor property)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.name"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractor method)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.set_files"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractorcollection method)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection.set_files"]], "i3featureextractor (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"]], "i3featureextractoricecube86 (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCube86"]], "i3featureextractoricecubedeepcore (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeDeepCore"]], "i3featureextractoricecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeUpgrade"]], "i3pulsenoisetruthflagicecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor (class in graphnet.data.extractors.i3genericextractor)": [[17, "graphnet.data.extractors.i3genericextractor.I3GenericExtractor"]], "graphnet.data.extractors.i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3galacticplanehybridrecoextractor (class in graphnet.data.extractors.i3hybridrecoextractor)": [[18, "graphnet.data.extractors.i3hybridrecoextractor.I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelextractor (class in graphnet.data.extractors.i3ntmuonlabelsextractor)": [[19, "graphnet.data.extractors.i3ntmuonlabelsextractor.I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor (class in graphnet.data.extractors.i3particleextractor)": [[20, "graphnet.data.extractors.i3particleextractor.I3ParticleExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor (class in graphnet.data.extractors.i3pisaextractor)": [[21, "graphnet.data.extractors.i3pisaextractor.I3PISAExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor (class in graphnet.data.extractors.i3quesoextractor)": [[22, "graphnet.data.extractors.i3quesoextractor.I3QUESOExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor (class in graphnet.data.extractors.i3retroextractor)": [[23, "graphnet.data.extractors.i3retroextractor.I3RetroExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeicextractor (class in graphnet.data.extractors.i3splinempeextractor)": [[24, "graphnet.data.extractors.i3splinempeextractor.I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor (class in graphnet.data.extractors.i3truthextractor)": [[25, "graphnet.data.extractors.i3truthextractor.I3TruthExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor (class in graphnet.data.extractors.i3tumextractor)": [[26, "graphnet.data.extractors.i3tumextractor.I3TUMExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "graphnet.data.extractors.utilities": [[27, "module-graphnet.data.extractors.utilities"]], "flatten_nested_dictionary() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.flatten_nested_dictionary"]], "graphnet.data.extractors.utilities.collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "serialise() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.serialise"]], "transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.transpose_list_of_dicts"]], "frame_is_montecarlo() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_montecarlo"]], "frame_is_noise() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_noise"]], "get_om_keys_and_pulseseries() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "break_cyclic_recursion() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.break_cyclic_recursion"]], "cast_object_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_object_to_pure_python"]], "cast_pulse_series_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_pulse_series_to_pure_python"]], "get_member_variables() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.get_member_variables"]], "graphnet.data.extractors.utilities.types": [[30, "module-graphnet.data.extractors.utilities.types"]], "is_boost_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_class"]], "is_boost_enum() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_enum"]], "is_icecube_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_icecube_class"]], "is_method() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_method"]], "is_type() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_type"]], "graphnet.data.parquet": [[31, "module-graphnet.data.parquet"]], "parquetdataconverter (class in graphnet.data.parquet.parquet_dataconverter)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter"]], "file_suffix (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter attribute)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.file_suffix"]], "graphnet.data.parquet.parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "merge_files() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.merge_files"]], "save_data() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.save_data"]], "graphnet.data.sqlite": [[34, "module-graphnet.data.sqlite"]], "sqlitedataconverter (class in graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter"]], "any_pulsemap_is_non_empty() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.any_pulsemap_is_non_empty"]], "construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe"]], "file_suffix (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter attribute)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.file_suffix"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "is_mc_tree() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_mc_tree"]], "is_pulse_map() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_pulse_map"]], "merge_files() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.merge_files"]], "save_data() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.save_data"]], "attach_index() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.attach_index"]], "create_table() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table"]], "create_table_and_save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table_and_save_to_sql"]], "database_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_exists"]], "database_table_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_table_exists"]], "graphnet.data.sqlite.sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "run_sql_code() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.run_sql_code"]], "save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.save_to_sql"]], "graphnet.data.utilities": [[37, "module-graphnet.data.utilities"]], "parquettosqliteconverter (class in graphnet.data.utilities.parquet_to_sqlite)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "run() (graphnet.data.utilities.parquet_to_sqlite.parquettosqliteconverter method)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter.run"]], "graphnet.data.utilities.random": [[39, "module-graphnet.data.utilities.random"]], "pairwise_shuffle() (in module graphnet.data.utilities.random)": [[39, "graphnet.data.utilities.random.pairwise_shuffle"]], "stringselectionresolver (class in graphnet.data.utilities.string_selection_resolver)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "resolve() (graphnet.data.utilities.string_selection_resolver.stringselectionresolver method)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver.resolve"]], "graphnet.deployment": [[41, "module-graphnet.deployment"]], "graphnet.pisa": [[74, "module-graphnet.pisa"]], "contourfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.ContourFitter"]], "weightfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.WeightFitter"]], "config_updater() (in module graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.config_updater"]], "fit_1d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_1d_contour"]], "fit_2d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_2d_contour"]], "fit_weights() (graphnet.pisa.fitting.weightfitter method)": [[75, "graphnet.pisa.fitting.WeightFitter.fit_weights"]], "graphnet.pisa.fitting": [[75, "module-graphnet.pisa.fitting"]], "graphnet.pisa.plotting": [[76, "module-graphnet.pisa.plotting"]], "plot_1d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_1D_contour"]], "plot_2d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_2D_contour"]], "read_entry() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.read_entry"]], "graphnet.training": [[77, "module-graphnet.training"]], "bjoernlow (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.BjoernLow"]], "uniform (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.Uniform"]], "weightfitter (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.WeightFitter"]], "fit() (graphnet.training.weight_fitting.weightfitter method)": [[82, "graphnet.training.weight_fitting.WeightFitter.fit"]], "graphnet.training.weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "graphnet.utilities": [[83, "module-graphnet.utilities"]], "argumentparser (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.ArgumentParser"]], "options (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.Options"]], "contains() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.contains"]], "graphnet.utilities.argparse": [[84, "module-graphnet.utilities.argparse"]], "pop_default() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.pop_default"]], "standard_arguments (graphnet.utilities.argparse.argumentparser attribute)": [[84, "graphnet.utilities.argparse.ArgumentParser.standard_arguments"]], "with_standard_arguments() (graphnet.utilities.argparse.argumentparser method)": [[84, "graphnet.utilities.argparse.ArgumentParser.with_standard_arguments"]], "graphnet.utilities.decorators": [[92, "module-graphnet.utilities.decorators"]], "find_i3_files() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.find_i3_files"]], "graphnet.utilities.filesys": [[93, "module-graphnet.utilities.filesys"]], "has_extension() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.has_extension"]], "is_gcd_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_gcd_file"]], "is_i3_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_i3_file"]], "graphnet.utilities.imports": [[94, "module-graphnet.utilities.imports"]], "has_icecube_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_icecube_package"]], "has_pisa_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_pisa_package"]], "has_torch_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_torch_package"]], "requires_icecube() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.requires_icecube"]], "logger (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.Logger"]], "repeatfilter (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.RepeatFilter"]], "critical() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.critical"]], "debug() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.debug"]], "error() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.error"]], "file_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.file_handlers"]], "filter() (graphnet.utilities.logging.repeatfilter method)": [[95, "graphnet.utilities.logging.RepeatFilter.filter"]], "graphnet.utilities.logging": [[95, "module-graphnet.utilities.logging"]], "handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.handlers"]], "info() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.info"]], "nb_repeats_allowed (graphnet.utilities.logging.repeatfilter attribute)": [[95, "graphnet.utilities.logging.RepeatFilter.nb_repeats_allowed"]], "setlevel() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.setLevel"]], "stream_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.stream_handlers"]], "warning() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning"]], "warning_once() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning_once"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["about", "api/graphnet", "api/graphnet.constants", "api/graphnet.data", "api/graphnet.data.constants", "api/graphnet.data.dataconverter", "api/graphnet.data.dataloader", "api/graphnet.data.dataset", "api/graphnet.data.dataset.dataset", "api/graphnet.data.dataset.parquet", "api/graphnet.data.dataset.parquet.parquet_dataset", "api/graphnet.data.dataset.sqlite", "api/graphnet.data.dataset.sqlite.sqlite_dataset", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed", "api/graphnet.data.extractors", "api/graphnet.data.extractors.i3extractor", "api/graphnet.data.extractors.i3featureextractor", "api/graphnet.data.extractors.i3genericextractor", "api/graphnet.data.extractors.i3hybridrecoextractor", "api/graphnet.data.extractors.i3ntmuonlabelsextractor", "api/graphnet.data.extractors.i3particleextractor", "api/graphnet.data.extractors.i3pisaextractor", "api/graphnet.data.extractors.i3quesoextractor", "api/graphnet.data.extractors.i3retroextractor", "api/graphnet.data.extractors.i3splinempeextractor", "api/graphnet.data.extractors.i3truthextractor", "api/graphnet.data.extractors.i3tumextractor", "api/graphnet.data.extractors.utilities", "api/graphnet.data.extractors.utilities.collections", "api/graphnet.data.extractors.utilities.frames", "api/graphnet.data.extractors.utilities.types", "api/graphnet.data.parquet", "api/graphnet.data.parquet.parquet_dataconverter", "api/graphnet.data.pipeline", "api/graphnet.data.sqlite", "api/graphnet.data.sqlite.sqlite_dataconverter", "api/graphnet.data.sqlite.sqlite_utilities", "api/graphnet.data.utilities", "api/graphnet.data.utilities.parquet_to_sqlite", "api/graphnet.data.utilities.random", "api/graphnet.data.utilities.string_selection_resolver", "api/graphnet.deployment", "api/graphnet.deployment.i3modules", "api/graphnet.deployment.i3modules.deployer", "api/graphnet.deployment.i3modules.graphnet_module", "api/graphnet.models", "api/graphnet.models.coarsening", "api/graphnet.models.components", "api/graphnet.models.components.layers", "api/graphnet.models.components.pool", "api/graphnet.models.detector", "api/graphnet.models.detector.detector", "api/graphnet.models.detector.icecube", "api/graphnet.models.detector.prometheus", "api/graphnet.models.gnn", "api/graphnet.models.gnn.convnet", "api/graphnet.models.gnn.dynedge", "api/graphnet.models.gnn.dynedge_jinst", "api/graphnet.models.gnn.dynedge_kaggle_tito", "api/graphnet.models.gnn.gnn", "api/graphnet.models.graphs", "api/graphnet.models.graphs.edges", "api/graphnet.models.graphs.edges.edges", "api/graphnet.models.graphs.graph_definition", "api/graphnet.models.graphs.graphs", "api/graphnet.models.graphs.nodes", "api/graphnet.models.graphs.nodes.nodes", "api/graphnet.models.model", "api/graphnet.models.standard_model", "api/graphnet.models.task", "api/graphnet.models.task.classification", "api/graphnet.models.task.reconstruction", "api/graphnet.models.task.task", "api/graphnet.models.utils", "api/graphnet.pisa", "api/graphnet.pisa.fitting", "api/graphnet.pisa.plotting", "api/graphnet.training", "api/graphnet.training.callbacks", "api/graphnet.training.labels", "api/graphnet.training.loss_functions", "api/graphnet.training.utils", "api/graphnet.training.weight_fitting", "api/graphnet.utilities", "api/graphnet.utilities.argparse", "api/graphnet.utilities.config", "api/graphnet.utilities.config.base_config", "api/graphnet.utilities.config.configurable", "api/graphnet.utilities.config.dataset_config", "api/graphnet.utilities.config.model_config", "api/graphnet.utilities.config.parsing", "api/graphnet.utilities.config.training_config", "api/graphnet.utilities.decorators", "api/graphnet.utilities.filesys", "api/graphnet.utilities.imports", "api/graphnet.utilities.logging", "api/graphnet.utilities.maths", "api/modules", "contribute", "index", "install"], "filenames": ["about.md", "api/graphnet.rst", "api/graphnet.constants.rst", "api/graphnet.data.rst", "api/graphnet.data.constants.rst", "api/graphnet.data.dataconverter.rst", "api/graphnet.data.dataloader.rst", "api/graphnet.data.dataset.rst", "api/graphnet.data.dataset.dataset.rst", "api/graphnet.data.dataset.parquet.rst", "api/graphnet.data.dataset.parquet.parquet_dataset.rst", "api/graphnet.data.dataset.sqlite.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.rst", "api/graphnet.data.extractors.rst", "api/graphnet.data.extractors.i3extractor.rst", "api/graphnet.data.extractors.i3featureextractor.rst", "api/graphnet.data.extractors.i3genericextractor.rst", "api/graphnet.data.extractors.i3hybridrecoextractor.rst", "api/graphnet.data.extractors.i3ntmuonlabelsextractor.rst", "api/graphnet.data.extractors.i3particleextractor.rst", "api/graphnet.data.extractors.i3pisaextractor.rst", "api/graphnet.data.extractors.i3quesoextractor.rst", "api/graphnet.data.extractors.i3retroextractor.rst", "api/graphnet.data.extractors.i3splinempeextractor.rst", "api/graphnet.data.extractors.i3truthextractor.rst", "api/graphnet.data.extractors.i3tumextractor.rst", "api/graphnet.data.extractors.utilities.rst", "api/graphnet.data.extractors.utilities.collections.rst", "api/graphnet.data.extractors.utilities.frames.rst", "api/graphnet.data.extractors.utilities.types.rst", "api/graphnet.data.parquet.rst", "api/graphnet.data.parquet.parquet_dataconverter.rst", "api/graphnet.data.pipeline.rst", "api/graphnet.data.sqlite.rst", "api/graphnet.data.sqlite.sqlite_dataconverter.rst", "api/graphnet.data.sqlite.sqlite_utilities.rst", "api/graphnet.data.utilities.rst", "api/graphnet.data.utilities.parquet_to_sqlite.rst", "api/graphnet.data.utilities.random.rst", "api/graphnet.data.utilities.string_selection_resolver.rst", "api/graphnet.deployment.rst", "api/graphnet.deployment.i3modules.rst", "api/graphnet.deployment.i3modules.deployer.rst", "api/graphnet.deployment.i3modules.graphnet_module.rst", "api/graphnet.models.rst", "api/graphnet.models.coarsening.rst", "api/graphnet.models.components.rst", "api/graphnet.models.components.layers.rst", "api/graphnet.models.components.pool.rst", "api/graphnet.models.detector.rst", "api/graphnet.models.detector.detector.rst", "api/graphnet.models.detector.icecube.rst", "api/graphnet.models.detector.prometheus.rst", "api/graphnet.models.gnn.rst", "api/graphnet.models.gnn.convnet.rst", "api/graphnet.models.gnn.dynedge.rst", "api/graphnet.models.gnn.dynedge_jinst.rst", "api/graphnet.models.gnn.dynedge_kaggle_tito.rst", "api/graphnet.models.gnn.gnn.rst", "api/graphnet.models.graphs.rst", "api/graphnet.models.graphs.edges.rst", "api/graphnet.models.graphs.edges.edges.rst", "api/graphnet.models.graphs.graph_definition.rst", "api/graphnet.models.graphs.graphs.rst", "api/graphnet.models.graphs.nodes.rst", "api/graphnet.models.graphs.nodes.nodes.rst", "api/graphnet.models.model.rst", "api/graphnet.models.standard_model.rst", "api/graphnet.models.task.rst", "api/graphnet.models.task.classification.rst", "api/graphnet.models.task.reconstruction.rst", "api/graphnet.models.task.task.rst", "api/graphnet.models.utils.rst", "api/graphnet.pisa.rst", "api/graphnet.pisa.fitting.rst", "api/graphnet.pisa.plotting.rst", "api/graphnet.training.rst", "api/graphnet.training.callbacks.rst", "api/graphnet.training.labels.rst", "api/graphnet.training.loss_functions.rst", "api/graphnet.training.utils.rst", "api/graphnet.training.weight_fitting.rst", "api/graphnet.utilities.rst", "api/graphnet.utilities.argparse.rst", "api/graphnet.utilities.config.rst", "api/graphnet.utilities.config.base_config.rst", "api/graphnet.utilities.config.configurable.rst", "api/graphnet.utilities.config.dataset_config.rst", "api/graphnet.utilities.config.model_config.rst", "api/graphnet.utilities.config.parsing.rst", "api/graphnet.utilities.config.training_config.rst", "api/graphnet.utilities.decorators.rst", "api/graphnet.utilities.filesys.rst", "api/graphnet.utilities.imports.rst", "api/graphnet.utilities.logging.rst", "api/graphnet.utilities.maths.rst", "api/modules.rst", "contribute.md", "index.rst", "install.md"], "titles": ["About", "API", "constants", "data", "constants", "dataconverter", "dataloader", "dataset", "dataset", "parquet", "parquet_dataset", "sqlite", "sqlite_dataset", "sqlite_dataset_perturbed", "extractors", "i3extractor", "i3featureextractor", "i3genericextractor", "i3hybridrecoextractor", "i3ntmuonlabelsextractor", "i3particleextractor", "i3pisaextractor", "i3quesoextractor", "i3retroextractor", "i3splinempeextractor", "i3truthextractor", "i3tumextractor", "utilities", "collections", "frames", "types", "parquet", "parquet_dataconverter", "pipeline", "sqlite", "sqlite_dataconverter", "sqlite_utilities", "utilities", "parquet_to_sqlite", "random", "string_selection_resolver", "deployment", "i3modules", "deployer", "graphnet_module", "models", "coarsening", "components", "layers", "pool", "detector", "detector", "icecube", "prometheus", "gnn", "convnet", "dynedge", "dynedge_jinst", "dynedge_kaggle_tito", "gnn", "graphs", "edges", "edges", "graph_definition", "graphs", "nodes", "nodes", "model", "standard_model", "task", "classification", "reconstruction", "task", "utils", "pisa", "fitting", "plotting", "training", "callbacks", "labels", "loss_functions", "utils", "weight_fitting", "utilities", "argparse", "config", "base_config", "configurable", "dataset_config", "model_config", "parsing", "training_config", "decorators", "filesys", "imports", "logging", "maths", "src", "Contribute", "About", "Install"], "terms": {"graphnet": [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 33, 35, 36, 37, 38, 39, 40, 41, 44, 45, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 98, 99, 100], "i": [0, 1, 8, 10, 12, 13, 15, 17, 28, 29, 30, 35, 36, 39, 40, 44, 46, 49, 55, 56, 62, 66, 70, 71, 72, 73, 76, 78, 79, 80, 82, 84, 89, 90, 93, 94, 95, 98, 99, 100], "an": [0, 5, 30, 32, 33, 35, 40, 44, 63, 80, 93, 95, 98, 99, 100], "open": [0, 98, 99], "sourc": [0, 4, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 98, 99], "python": [0, 1, 5, 14, 15, 17, 28, 30, 98, 99, 100], "framework": [0, 99], "aim": [0, 1, 98, 99], "provid": [0, 1, 8, 10, 12, 13, 44, 45, 80, 98, 99, 100], "high": [0, 99], "qualiti": [0, 99], "user": [0, 45, 78, 99, 100], "friendli": [0, 99], "end": [0, 1, 5, 32, 35, 99], "function": [0, 5, 6, 8, 30, 36, 39, 44, 46, 49, 52, 53, 63, 67, 70, 71, 72, 73, 75, 76, 80, 81, 83, 88, 89, 90, 93, 94, 96, 99], "perform": [0, 46, 48, 49, 54, 56, 58, 68, 70, 71, 72, 99], "reconstruct": [0, 1, 16, 18, 19, 23, 24, 26, 33, 41, 45, 58, 69, 72, 99], "task": [0, 1, 45, 68, 70, 71, 80, 98, 99], "neutrino": [0, 1, 48, 58, 75, 99], "telescop": [0, 1, 99], "us": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 15, 20, 25, 27, 28, 32, 33, 35, 36, 37, 38, 40, 41, 44, 45, 48, 49, 51, 56, 57, 58, 62, 63, 64, 67, 69, 70, 71, 72, 73, 75, 78, 79, 80, 82, 83, 84, 85, 86, 88, 89, 90, 91, 94, 95, 98, 99, 100], "graph": [0, 1, 6, 8, 10, 12, 13, 44, 45, 48, 49, 51, 61, 62, 63, 65, 66, 73, 79, 81, 98, 99], "neural": [0, 1, 99], "network": [0, 1, 55, 99], "gnn": [0, 1, 33, 45, 55, 56, 57, 58, 63, 68, 99, 100], "make": [0, 5, 82, 88, 89, 98, 99, 100], "fast": [0, 99, 100], "easi": [0, 99], "train": [0, 1, 7, 13, 40, 41, 44, 63, 68, 78, 79, 80, 81, 82, 84, 88, 89, 91, 97, 99, 100], "complex": [0, 45, 99], "model": [0, 1, 13, 41, 44, 46, 47, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 68, 69, 70, 71, 72, 73, 76, 77, 78, 80, 81, 84, 86, 88, 89, 91, 97, 99, 100], "can": [0, 1, 8, 10, 12, 13, 15, 17, 20, 38, 44, 49, 63, 75, 76, 82, 84, 86, 88, 89, 98, 99, 100], "event": [0, 1, 8, 10, 12, 13, 22, 36, 38, 40, 44, 49, 63, 70, 71, 72, 73, 75, 80, 82, 88, 99], "state": [0, 99], "art": [0, 99], "arbitrari": [0, 99], "detector": [0, 1, 25, 45, 52, 53, 63, 64, 66, 68, 99], "configur": [0, 1, 8, 45, 67, 68, 75, 83, 85, 86, 88, 89, 91, 95, 99], "infer": [0, 1, 33, 41, 44, 68, 70, 71, 72, 99, 100], "time": [0, 4, 36, 46, 49, 71, 95, 99, 100], "ar": [0, 1, 4, 5, 8, 10, 12, 13, 17, 30, 32, 35, 38, 40, 44, 49, 56, 58, 60, 61, 62, 63, 64, 65, 70, 75, 80, 82, 88, 89, 98, 99, 100], "order": [0, 28, 46, 73, 99], "magnitud": [0, 99], "faster": [0, 99], "than": [0, 6, 70, 71, 72, 81, 95, 99], "tradit": [0, 99], "techniqu": [0, 99], "common": [0, 1, 80, 86, 91, 92, 94, 99], "ml": [0, 1, 99], "develop": [0, 1, 98, 99, 100], "physicist": [0, 1, 99], "wish": [0, 98, 99], "tool": [0, 1, 99], "research": [0, 99], "By": [0, 38, 70, 71, 72, 99], "unit": [0, 5, 94, 98, 99], "both": [0, 17, 70, 71, 72, 76, 99], "group": [0, 5, 32, 35, 49, 99], "increas": [0, 78, 99], "longev": [0, 99], "usabl": [0, 99], "individu": [0, 5, 8, 10, 12, 13, 49, 56, 73, 99], "code": [0, 25, 36, 63, 88, 89, 99], "contribut": [0, 99, 100], "from": [0, 1, 6, 8, 10, 12, 13, 14, 15, 17, 19, 20, 22, 28, 29, 30, 33, 35, 38, 44, 49, 58, 62, 63, 66, 67, 70, 71, 72, 73, 76, 78, 79, 80, 86, 87, 88, 89, 91, 95, 98, 99, 100], "build": [0, 1, 45, 51, 62, 66, 67, 86, 88, 89, 99], "gener": [0, 5, 8, 10, 12, 13, 17, 44, 60, 61, 65, 70, 80, 99], "reusabl": [0, 99], "softwar": [0, 80, 99], "packag": [0, 1, 39, 90, 93, 94, 98, 99, 100], "base": [0, 4, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 33, 35, 38, 40, 44, 46, 48, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 82, 84, 86, 87, 88, 89, 91, 94, 95, 99], "engin": [0, 99], "best": [0, 98, 99], "practic": [0, 98, 99], "lower": [0, 76, 99], "technic": [0, 99], "threshold": [0, 44, 62, 99], "most": [0, 1, 40, 99, 100], "scientif": [0, 1, 99], "problem": [0, 62, 98, 99], "The": [0, 5, 8, 10, 12, 28, 30, 33, 35, 36, 44, 46, 48, 49, 56, 58, 62, 63, 70, 71, 72, 73, 75, 76, 78, 79, 80, 99], "improv": [0, 1, 84, 99], "classif": [0, 1, 45, 69, 72, 80, 99], "yield": [0, 56, 75, 80, 99], "veri": [0, 40, 99], "accur": [0, 99], "e": [0, 1, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 28, 30, 32, 33, 35, 36, 40, 44, 46, 48, 49, 51, 52, 53, 55, 59, 62, 63, 66, 67, 68, 70, 71, 72, 73, 78, 79, 80, 82, 86, 95, 98, 99, 100], "g": [0, 1, 5, 8, 10, 12, 13, 25, 28, 30, 32, 33, 35, 36, 40, 44, 49, 63, 70, 71, 72, 73, 82, 95, 98, 99, 100], "low": [0, 99], "energi": [0, 4, 33, 70, 71, 72, 82, 99], "observ": [0, 99], "icecub": [0, 1, 16, 29, 30, 45, 48, 50, 58, 94, 99, 100], "here": [0, 98, 99, 100], "implement": [0, 1, 5, 15, 31, 32, 34, 35, 48, 55, 56, 57, 58, 62, 80, 98, 99], "wa": [0, 99], "appli": [0, 8, 10, 12, 13, 15, 49, 55, 56, 57, 58, 59, 68, 90, 99], "oscil": [0, 74, 99], "lead": [0, 99], "signific": [0, 99], "angular": [0, 99], "rang": [0, 70, 71, 72, 99], "relev": [0, 1, 30, 39, 93, 98, 99], "studi": [0, 99], "furthermor": [0, 99], "shown": [0, 99], "could": [0, 98, 99], "muon": [0, 19, 99], "v": [0, 99], "therebi": [0, 1, 88, 89, 99], "effici": [0, 99], "puriti": [0, 99], "sampl": [0, 40, 99], "analysi": [0, 33, 99, 100], "similarli": [0, 30, 99], "ha": [0, 5, 30, 32, 35, 36, 44, 55, 80, 93, 99, 100], "great": [0, 99], "point": [0, 24, 79, 80, 99], "analys": [0, 41, 74, 99], "final": [0, 49, 78, 88, 99], "millisecond": [0, 99], "allow": [0, 41, 45, 49, 78, 86, 91, 99, 100], "whole": [0, 99], "new": [0, 1, 35, 48, 86, 91, 98, 99], "type": [0, 5, 6, 8, 10, 12, 13, 14, 15, 27, 28, 29, 32, 35, 36, 38, 39, 40, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 72, 73, 75, 76, 78, 80, 81, 82, 84, 86, 87, 88, 89, 90, 93, 94, 95, 96, 98, 99], "cosmic": [0, 99], "alert": [0, 99], "which": [0, 8, 10, 12, 13, 15, 16, 25, 29, 33, 40, 44, 46, 49, 56, 67, 70, 75, 80, 84, 99, 100], "were": [0, 99], "previous": [0, 99], "unfeas": [0, 99], "possibl": [0, 28, 98, 99], "identifi": [0, 5, 8, 10, 12, 13, 25, 88, 89, 99], "10": [0, 33, 84, 99], "tev": [0, 99], "monitor": [0, 99], "rate": [0, 78, 99], "direct": [0, 58, 70, 71, 72, 77, 79, 99], "real": [0, 99], "thi": [0, 3, 5, 8, 10, 12, 13, 15, 17, 30, 32, 35, 36, 39, 44, 45, 49, 56, 66, 68, 70, 71, 72, 73, 75, 76, 78, 80, 82, 86, 88, 89, 91, 95, 98, 99, 100], "enabl": [0, 3, 99], "first": [0, 78, 86, 91, 98, 99], "ever": [0, 99], "despit": [0, 99], "larg": [0, 80, 99], "background": [0, 99], "origin": [0, 75, 99], "compris": [0, 99], "number": [0, 5, 8, 10, 12, 13, 32, 33, 35, 40, 48, 49, 55, 56, 57, 58, 59, 62, 64, 66, 70, 71, 72, 78, 84, 99], "modul": [0, 3, 8, 30, 33, 41, 44, 45, 48, 50, 54, 60, 61, 63, 64, 65, 67, 69, 74, 77, 83, 85, 88, 89, 90, 91, 94, 99], "necessari": [0, 28, 98, 99], "workflow": [0, 99], "ingest": [0, 1, 3, 50, 99], "raw": [0, 66, 99], "data": [0, 1, 4, 5, 6, 8, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 46, 48, 49, 50, 51, 52, 55, 56, 57, 58, 59, 62, 63, 64, 67, 68, 70, 71, 72, 73, 79, 81, 84, 86, 88, 91, 94, 97, 99, 100], "domain": [0, 1, 3, 41, 99], "specif": [0, 1, 3, 5, 8, 10, 12, 16, 30, 31, 32, 34, 35, 36, 41, 46, 49, 50, 51, 52, 53, 54, 59, 62, 63, 66, 68, 69, 70, 71, 72, 80, 98, 99, 100], "format": [0, 1, 3, 5, 8, 28, 32, 35, 76, 88, 98, 99, 100], "deploi": [0, 1, 41, 44, 99], "chain": [0, 1, 41, 45, 68, 99, 100], "illustr": [0, 98, 99], "figur": [0, 76, 99], "level": [0, 8, 10, 12, 13, 25, 36, 46, 49, 62, 67, 95, 99, 100], "overview": [0, 99], "typic": [0, 28, 99], "convert": [0, 1, 3, 5, 28, 32, 35, 38, 99, 100], "industri": [0, 3, 99], "standard": [0, 3, 4, 5, 13, 32, 35, 40, 52, 53, 63, 66, 68, 84, 98, 99], "intermedi": [0, 1, 3, 5, 8, 32, 35, 55, 99, 100], "file": [0, 1, 3, 5, 8, 10, 12, 13, 15, 28, 32, 35, 38, 39, 44, 63, 67, 75, 78, 80, 84, 85, 86, 87, 88, 89, 93, 95, 99, 100], "read": [0, 3, 8, 10, 12, 13, 28, 51, 56, 68, 69, 99, 100], "simpl": [0, 45, 99], "physic": [0, 1, 15, 29, 30, 41, 44, 45, 69, 70, 71, 72, 99], "orient": [0, 45, 99], "compon": [0, 1, 45, 48, 49, 68, 99], "manag": [0, 15, 77, 99], "experi": [0, 1, 77, 99], "log": [0, 1, 71, 77, 78, 80, 83, 99, 100], "deploy": [0, 1, 42, 44, 63, 97, 99], "modular": [0, 45, 99], "subclass": [0, 45, 99], "torch": [0, 8, 10, 12, 13, 45, 48, 63, 64, 67, 68, 94, 99, 100], "nn": [0, 45, 48, 62, 64, 99], "mean": [0, 5, 8, 10, 12, 13, 32, 35, 45, 56, 58, 80, 89, 99], "onli": [0, 1, 8, 10, 12, 13, 45, 49, 70, 71, 72, 75, 82, 89, 94, 99, 100], "need": [0, 28, 45, 67, 80, 99, 100], "import": [0, 1, 36, 45, 83, 99], "few": [0, 45, 98, 99], "exist": [0, 8, 10, 12, 13, 33, 35, 36, 45, 79, 88, 99], "purpos": [0, 45, 80, 99], "built": [0, 45, 99], "them": [0, 1, 28, 45, 56, 70, 71, 72, 75, 99, 100], "togeth": [0, 45, 62, 68, 99], "form": [0, 45, 70, 86, 91, 99], "complet": [0, 45, 68, 99], "extend": [0, 1, 99], "suit": [0, 99], "through": [0, 80, 99], "layer": [0, 45, 47, 49, 55, 56, 57, 58, 70, 71, 72, 99], "connect": [0, 62, 63, 66, 80, 99], "etc": [0, 80, 95, 99], "optimis": [0, 1, 99], "differ": [0, 8, 10, 12, 13, 15, 64, 68, 98, 99, 100], "track": [0, 15, 19, 71, 98, 99], "These": [0, 63, 98, 99], "prepar": [0, 80, 99], "satisfi": [0, 99], "o": [0, 70, 71, 72, 99], "load": [0, 6, 8, 39, 67, 86, 88, 99], "requir": [0, 21, 36, 70, 80, 88, 89, 91, 99, 100], "when": [0, 5, 8, 10, 12, 13, 28, 32, 35, 36, 44, 48, 56, 58, 79, 95, 98, 99, 100], "batch": [0, 6, 33, 46, 48, 49, 68, 73, 81, 84, 99], "do": [0, 44, 80, 88, 89, 98, 99, 100], "predict": [0, 20, 24, 26, 33, 44, 55, 67, 68, 70, 71, 72, 80, 81, 99], "either": [0, 8, 10, 12, 80, 99, 100], "contain": [0, 5, 8, 10, 12, 13, 28, 29, 32, 33, 35, 44, 56, 60, 61, 63, 64, 65, 67, 70, 71, 72, 80, 82, 84, 99, 100], "imag": [0, 1, 98, 99, 100], "portabl": [0, 99], "depend": [0, 99, 100], "free": [0, 80, 99], "split": [0, 46, 99], "up": [0, 5, 32, 35, 44, 98, 99, 100], "interfac": [0, 74, 99, 100], "block": [0, 1, 99], "pre": [0, 13, 51, 63, 79, 98, 99], "directli": [0, 15, 99], "while": [0, 17, 78, 99], "continu": [0, 80, 99], "expand": [0, 99], "": [0, 5, 6, 8, 10, 12, 13, 15, 28, 35, 38, 55, 56, 68, 70, 71, 72, 73, 78, 82, 84, 88, 89, 95, 96, 99, 100], "capabl": [0, 99], "project": [0, 98, 99], "receiv": [0, 99], "fund": [0, 99], "european": [0, 99], "union": [0, 6, 8, 10, 12, 13, 17, 28, 30, 44, 46, 48, 49, 56, 67, 68, 70, 71, 72, 88, 91, 93, 99], "horizon": [0, 99], "2020": [0, 99], "innov": [0, 99], "programm": [0, 99], "under": [0, 13, 99], "mari": [0, 99], "sk\u0142odowska": [0, 99], "curi": [0, 99], "grant": [0, 80, 99], "agreement": [0, 98, 99], "No": [0, 99], "890778": [0, 99], "work": [0, 4, 29, 98, 99, 100], "rasmu": [0, 57, 99], "\u00f8rs\u00f8e": [0, 99], "partli": [0, 99], "punch4nfdi": [0, 99], "consortium": [0, 99], "support": [0, 30, 98, 99, 100], "dfg": [0, 99], "nfdi": [0, 99], "39": [0, 99, 100], "1": [0, 5, 8, 28, 32, 35, 40, 46, 49, 56, 58, 62, 64, 70, 71, 72, 73, 78, 80, 82, 88, 99, 100], "germani": [0, 99], "conveni": [1, 98, 100], "collabor": 1, "solv": [1, 98], "It": [1, 28, 36, 44, 98], "leverag": 1, "advanc": [1, 49], "machin": [1, 100], "learn": [1, 44, 78, 100], "without": [1, 62, 66, 75, 80, 100], "have": [1, 5, 17, 32, 35, 36, 40, 49, 63, 70, 71, 72, 98, 100], "expert": 1, "themselv": [1, 88, 89], "acceler": 1, "area": 1, "phyic": 1, "design": 1, "principl": 1, "all": [1, 5, 8, 10, 12, 13, 15, 17, 32, 35, 36, 44, 48, 49, 51, 56, 59, 63, 67, 72, 80, 86, 87, 88, 89, 90, 91, 95, 98, 100], "streamlin": 1, "process": [1, 5, 13, 15, 44, 51, 56, 98, 100], "transform": [1, 49, 70, 71, 72, 82], "extens": [1, 93], "basic": 1, "across": [1, 2, 8, 10, 12, 13, 30, 37, 49, 68, 80, 83, 84, 85, 95], "variou": 1, "easili": 1, "architectur": [1, 55, 56, 57, 58, 68], "main": [1, 54, 63, 68, 98, 100], "featur": [1, 3, 4, 5, 8, 10, 12, 13, 16, 33, 44, 48, 49, 51, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 70, 73, 81, 88, 98], "i3": [1, 5, 15, 29, 30, 32, 35, 39, 44, 93, 100], "more": [1, 8, 36, 39, 86, 88, 89, 91, 95], "index": [1, 5, 8, 10, 12, 30, 36, 49, 78], "sqlite": [1, 3, 7, 12, 13, 33, 35, 36, 38, 100], "suitabl": 1, "plug": 1, "plai": 1, "abstract": [1, 5, 8, 51, 59, 63, 67, 72, 87], "awai": 1, "detail": [1, 100], "expos": 1, "physicst": 1, "what": [1, 63, 98], "i3modul": [1, 41, 44], "includ": [1, 13, 67, 68, 75, 80, 86, 98], "docker": 1, "run": [1, 38], "containeris": 1, "fashion": 1, "subpackag": [1, 3, 7, 14, 41, 45, 60, 83], "dataset": [1, 3, 6, 9, 10, 11, 12, 13, 19, 40, 63, 84, 88], "extractor": [1, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35, 44], "parquet": [1, 3, 7, 10, 32, 38, 100], "util": [1, 3, 14, 28, 29, 30, 36, 38, 39, 40, 45, 77, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 97], "constant": [1, 3, 97], "dataconvert": [1, 3, 32, 35], "dataload": [1, 3, 33, 63, 67, 68, 81, 91], "pipelin": [1, 3], "coarsen": [1, 45, 49], "standard_model": [1, 45], "pisa": [1, 21, 33, 75, 76, 94, 97, 100], "fit": [1, 67, 74, 76, 80, 82, 91], "plot": [1, 74], "callback": [1, 67, 77], "label": [1, 8, 19, 22, 55, 63, 68, 72, 76, 77, 81], "loss_funct": [1, 70, 71, 72, 77], "weight_fit": [1, 77], "config": [1, 6, 40, 75, 80, 83, 84, 86, 87, 88, 89, 90, 91], "argpars": [1, 83], "decor": [1, 5, 83, 94], "filesi": [1, 83], "math": [1, 83], "submodul": [1, 3, 7, 9, 11, 14, 27, 31, 34, 37, 42, 45, 47, 50, 54, 60, 61, 65, 69, 74, 77, 83, 85, 90], "global": [2, 4, 56, 58, 67], "i3extractor": [3, 5, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35], "i3featureextractor": [3, 4, 14, 35, 44], "i3genericextractor": [3, 14, 35], "i3hybridrecoextractor": [3, 14], "i3ntmuonlabelsextractor": [3, 14], "i3particleextractor": [3, 14], "i3pisaextractor": [3, 14], "i3quesoextractor": [3, 14], "i3retroextractor": [3, 14], "i3splinempeextractor": [3, 14], "i3truthextractor": [3, 4, 14], "i3tumextractor": [3, 14], "parquet_dataconvert": [3, 31], "sqlite_dataconvert": [3, 34], "sqlite_util": [3, 34], "parquet_to_sqlit": [3, 37], "random": [3, 8, 10, 12, 13, 37, 40, 88], "string_selection_resolv": [3, 37], "truth": [3, 4, 8, 10, 12, 13, 16, 25, 33, 36, 63, 81, 82, 88], "fileset": [3, 5], "init_global_index": [3, 5], "cache_output_fil": [3, 5], "collate_fn": [3, 6, 77, 81], "do_shuffl": [3, 6], "insqlitepipelin": [3, 33], "class": [4, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 33, 34, 35, 38, 40, 44, 46, 48, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 82, 84, 86, 87, 88, 89, 90, 91, 95, 98], "object": [4, 5, 8, 10, 12, 13, 15, 17, 28, 30, 44, 49, 51, 63, 70, 71, 72, 75, 84, 95], "namespac": [4, 67], "name": [4, 5, 6, 8, 10, 12, 13, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 30, 32, 33, 35, 36, 38, 44, 62, 63, 64, 66, 67, 70, 71, 72, 75, 79, 82, 84, 86, 88, 89, 90, 91, 95, 98, 100], "icecube86": [4, 50, 52], "dom_x": [4, 44, 46], "dom_i": [4, 44, 46], "dom_z": [4, 44, 46], "dom_tim": 4, "charg": [4, 44, 80], "rde": [4, 46], "pmt_area": [4, 46], "deepcor": [4, 16, 52], "upgrad": [4, 16, 52, 100], "string": [4, 5, 8, 10, 12, 13, 28, 32, 35, 40, 49, 86], "pmt_number": 4, "dom_numb": 4, "pmt_dir_x": 4, "pmt_dir_i": 4, "pmt_dir_z": 4, "dom_typ": 4, "prometheu": [4, 45, 50], "sensor_pos_x": 4, "sensor_pos_i": 4, "sensor_pos_z": 4, "t": [4, 30, 36, 76, 78, 80, 100], "kaggl": [4, 48, 52, 58], "x": [4, 5, 25, 32, 35, 48, 49, 66, 67, 72, 73, 76, 80, 82], "y": [4, 25, 73, 76, 100], "z": [4, 5, 25, 32, 35, 73, 100], "auxiliari": 4, "energy_track": 4, "position_x": 4, "position_i": 4, "position_z": 4, "azimuth": [4, 71, 79], "zenith": [4, 71, 79], "pid": [4, 40, 88], "elast": 4, "sim_typ": 4, "interaction_typ": 4, "interaction_tim": [4, 71], "inelast": [4, 71], "stopped_muon": 4, "injection_energi": 4, "injection_typ": 4, "injection_interaction_typ": 4, "injection_zenith": 4, "injection_azimuth": 4, "injection_bjorkenx": 4, "injection_bjorkeni": 4, "injection_position_x": 4, "injection_position_i": 4, "injection_position_z": 4, "injection_column_depth": 4, "primary_lepton_1_typ": 4, "primary_hadron_1_typ": 4, "primary_lepton_1_position_x": 4, "primary_lepton_1_position_i": 4, "primary_lepton_1_position_z": 4, "primary_hadron_1_position_x": 4, "primary_hadron_1_position_i": 4, "primary_hadron_1_position_z": 4, "primary_lepton_1_direction_theta": 4, "primary_lepton_1_direction_phi": 4, "primary_hadron_1_direction_theta": 4, "primary_hadron_1_direction_phi": 4, "primary_lepton_1_energi": 4, "primary_hadron_1_energi": 4, "total_energi": 4, "i3_fil": [5, 15], "str": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 52, 53, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 79, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 95], "gcd_file": [5, 15, 44], "paramet": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96], "output_fil": [5, 32, 35], "global_index": 5, "avail": [5, 17, 33, 94], "pool": [5, 45, 46, 47, 56, 58], "worker": [5, 32, 33, 35, 39, 84, 95], "return": [5, 6, 8, 10, 12, 13, 15, 28, 29, 30, 32, 35, 36, 38, 39, 40, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 66, 67, 68, 70, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 93, 94, 95, 96], "none": [5, 6, 8, 10, 12, 13, 15, 17, 25, 29, 30, 32, 33, 35, 36, 38, 40, 44, 46, 48, 49, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 80, 81, 82, 84, 86, 87, 88, 90, 93, 95], "synchron": 5, "list": [5, 6, 8, 10, 12, 13, 15, 17, 25, 28, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 76, 78, 80, 81, 82, 88, 90, 91, 93, 95], "process_method": 5, "cach": 5, "output": [5, 32, 35, 38, 55, 56, 57, 59, 66, 67, 68, 75, 82, 88, 89, 100], "typevar": 5, "f": [5, 49], "bound": [5, 76], "callabl": [5, 6, 8, 30, 48, 49, 51, 52, 53, 63, 70, 71, 72, 81, 82, 86, 88, 89, 90, 94], "ani": [5, 6, 8, 10, 12, 28, 29, 30, 32, 35, 44, 48, 49, 56, 62, 63, 67, 68, 70, 72, 76, 80, 82, 84, 86, 87, 88, 89, 90, 91, 95, 100], "outdir": [5, 32, 33, 35, 38, 75], "gcd_rescu": [5, 32, 35, 93], "nb_files_to_batch": [5, 32, 35], "sequential_batch_pattern": [5, 32, 35], "input_file_batch_pattern": [5, 32, 35], "index_column": [5, 8, 10, 12, 13, 32, 35, 36, 40, 75, 81, 82, 88], "icetray_verbos": [5, 32, 35], "abc": [5, 8, 15, 33, 67, 79, 82, 87], "logger": [5, 8, 15, 33, 38, 40, 62, 67, 79, 82, 83, 95, 100], "construct": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 46, 47, 48, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 81, 82, 84, 87, 88, 89, 95], "regular": [5, 30, 32, 35], "express": [5, 32, 35, 67, 80], "accord": [5, 13, 32, 35, 46, 49, 62], "match": [5, 32, 35, 82, 93, 96], "certain": [5, 32, 35, 38, 75], "pattern": [5, 32, 35], "wildcard": [5, 32, 35], "same": [5, 30, 32, 35, 36, 46, 49, 70, 73, 78, 90, 95], "input": [5, 8, 10, 12, 13, 32, 33, 35, 44, 52, 55, 56, 57, 58, 59, 63, 66, 70, 72, 73, 86, 91], "replac": [5, 32, 35, 86, 88, 89, 91], "period": [5, 32, 35], "special": [5, 17, 32, 35, 44, 73], "interpret": [5, 32, 35, 70], "liter": [5, 32, 35], "charact": [5, 32, 35], "regex": [5, 32, 35], "For": [5, 30, 32, 35, 78], "instanc": [5, 8, 15, 25, 30, 32, 35, 44, 63, 67, 75, 79, 81, 87, 100], "A": [5, 8, 32, 33, 35, 44, 49, 64, 73, 75, 80, 82, 100], "_": [5, 32, 35], "0": [5, 8, 10, 12, 32, 35, 40, 44, 46, 49, 55, 56, 58, 62, 64, 73, 75, 76, 80, 88], "9": [5, 32, 35], "5": [5, 8, 10, 12, 32, 35, 40, 84, 100], "zst": [5, 32, 35], "find": [5, 32, 35, 93], "whose": [5, 32, 35, 44], "one": [5, 8, 32, 35, 36, 44, 49, 67, 88, 89, 93, 98, 100], "capit": [5, 32, 35], "letter": [5, 32, 35], "follow": [5, 32, 35, 56, 68, 80, 82, 98, 100], "underscor": [5, 32, 35], "five": [5, 32, 35], "upgrade_genie_step4_141020_a_000000": [5, 32, 35], "upgrade_genie_step4_141020_a_000001": [5, 32, 35], "upgrade_genie_step4_141020_a_000008": [5, 32, 35], "upgrade_genie_step4_141020_a_000009": [5, 32, 35], "would": [5, 32, 35, 98], "upgrade_genie_step4_141020_a_00000x": [5, 32, 35], "suffix": [5, 32, 35], "upgrade_genie_step4_141020_a_000010": [5, 32, 35], "separ": [5, 28, 32, 35, 78, 100], "upgrade_genie_step4_141020_a_00001x": [5, 32, 35], "int": [5, 6, 8, 10, 12, 13, 19, 22, 32, 33, 35, 40, 48, 49, 55, 56, 57, 58, 59, 62, 64, 66, 67, 68, 70, 71, 72, 73, 75, 78, 80, 81, 82, 84, 88, 91, 95], "properti": [5, 8, 15, 20, 30, 49, 59, 66, 68, 72, 79, 87, 95], "file_suffix": [5, 32, 35], "execut": [5, 36], "method": [5, 8, 10, 12, 15, 27, 28, 29, 30, 32, 35, 44, 48, 49, 71, 80, 82], "set": [5, 17, 70, 71, 72, 98], "inherit": [5, 15, 30, 51, 66, 80, 95], "path": [5, 8, 10, 12, 13, 36, 39, 44, 63, 67, 75, 76, 84, 86, 87, 88, 93, 100], "correspond": [5, 8, 10, 12, 13, 28, 30, 35, 39, 56, 63, 82, 93, 100], "gcd": [5, 15, 29, 39, 44, 93], "save_data": [5, 32, 35], "save": [5, 15, 28, 32, 35, 36, 67, 75, 80, 81, 82, 86, 87, 88, 89, 100], "ordereddict": [5, 32, 35], "extract": [5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 35, 38, 39, 44, 70, 71, 72], "merge_fil": [5, 32, 35], "input_fil": [5, 32, 35], "merg": [5, 32, 35, 80, 100], "result": [5, 32, 35, 49, 78, 80, 81, 90, 100], "option": [5, 8, 10, 12, 13, 25, 32, 33, 35, 44, 48, 49, 56, 58, 63, 64, 67, 70, 71, 72, 75, 76, 80, 82, 83, 84, 86, 88, 93, 100], "default": [5, 8, 10, 12, 13, 17, 25, 28, 32, 33, 35, 36, 38, 44, 48, 49, 55, 56, 57, 58, 62, 63, 64, 66, 67, 70, 71, 72, 75, 76, 78, 79, 80, 82, 84, 86, 88, 93], "current": [5, 32, 35, 40, 78, 98, 100], "rais": [5, 8, 17, 32, 67, 86, 91], "notimplementederror": [5, 32], "If": [5, 8, 17, 32, 33, 35, 67, 70, 71, 72, 75, 78, 82, 98, 100], "been": [5, 32, 44, 80, 98], "backend": [5, 9, 11, 32, 35], "question": 5, "get_map_funct": 5, "nb_file": 5, "map": [5, 8, 10, 12, 13, 16, 17, 35, 36, 44, 52, 53, 86, 88, 89, 91], "pure": [5, 14, 15, 17, 30], "multiprocess": [5, 100], "tupl": [5, 8, 10, 12, 29, 30, 48, 56, 58, 70, 71, 72, 73, 75, 76, 81, 84], "remov": [6, 81, 84], "less": [6, 81], "two": [6, 56, 75, 78, 80, 81], "dom": [6, 8, 10, 12, 13, 46, 49, 81], "hit": [6, 81], "should": [6, 8, 10, 12, 13, 15, 28, 40, 48, 49, 80, 81, 86, 88, 89, 91, 98, 100], "occur": [6, 81], "product": [6, 81], "selection_nam": 6, "check": [6, 29, 30, 35, 36, 84, 93, 94, 98, 100], "whether": [6, 29, 30, 35, 36, 56, 67, 80, 90, 93, 94], "shuffl": [6, 39, 81], "select": [6, 8, 10, 12, 13, 22, 40, 81, 82, 88, 98], "bool": [6, 29, 30, 35, 36, 40, 44, 46, 56, 67, 68, 75, 78, 80, 81, 82, 84, 90, 93, 94, 95], "batch_siz": [6, 33, 73, 81], "num_work": [6, 81], "persistent_work": [6, 81], "prefetch_factor": 6, "kwarg": [6, 48, 62, 67, 70, 72, 80, 82, 86, 95], "t_co": 6, "classmethod": [6, 8, 67, 80, 86, 87], "from_dataset_config": 6, "datasetconfig": [6, 8, 40, 85, 88], "dict": [6, 8, 13, 17, 28, 30, 33, 35, 51, 52, 53, 63, 67, 68, 75, 76, 78, 80, 81, 84, 86, 88, 89, 90, 91], "parquet_dataset": [7, 9], "sqlite_dataset": [7, 11], "sqlite_dataset_perturb": [7, 11], "columnmissingexcept": [7, 8], "load_modul": [7, 8, 67], "parse_graph_definit": [7, 8], "ensembledataset": [7, 8, 88], "except": 8, "indic": [8, 40, 49, 78, 84, 98], "miss": 8, "column": [8, 10, 12, 13, 36, 44, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 82], "class_nam": [8, 62, 67, 89, 95], "cfg": 8, "graphdefinit": [8, 10, 12, 44, 60, 61, 63, 64, 65, 68, 81, 98], "graph_definit": [8, 10, 12, 44, 45, 60, 68, 81, 88], "pulsemap": [8, 10, 12, 13, 16, 35, 44, 81, 88], "node_truth": [8, 10, 12, 13, 81, 88], "truth_tabl": [8, 10, 12, 13, 75, 81, 82, 88], "node_truth_t": [8, 10, 12, 13, 81, 88], "string_select": [8, 10, 12, 13, 81, 88], "dtype": [8, 10, 12, 13, 63, 64, 96], "loss_weight_t": [8, 10, 12, 13, 81, 88], "loss_weight_column": [8, 10, 12, 13, 63, 81, 88], "loss_weight_default_valu": [8, 10, 12, 13, 63, 88], "seed": [8, 10, 12, 13, 40, 81, 88], "puls": [8, 10, 12, 13, 16, 17, 29, 30, 35, 36, 44, 46, 49, 66, 73], "seri": [8, 10, 12, 13, 16, 17, 29, 30, 36, 44], "node": [8, 10, 12, 13, 45, 46, 49, 55, 56, 58, 60, 61, 62, 63, 64, 70, 71, 72, 73], "multipl": [8, 10, 12, 13, 15, 78, 88, 95], "store": [8, 10, 12, 13, 15, 33, 36, 75, 79], "ad": [8, 10, 12, 13, 16, 56, 63, 75], "attribut": [8, 10, 12, 13, 46, 70, 71, 72], "event_no": [8, 10, 12, 13, 36, 40, 82, 88], "uniqu": [8, 10, 12, 13, 36, 38, 88], "indici": [8, 10, 12, 13, 29, 40, 80], "tabl": [8, 10, 12, 13, 15, 33, 35, 36, 63, 75, 82], "inform": [8, 10, 12, 13, 15, 17, 25, 76], "subset": [8, 10, 12, 13, 48, 56, 58], "given": [8, 10, 12, 13, 35, 49, 62, 70, 71, 72, 82, 84], "queri": [8, 10, 12, 36, 40], "pass": [8, 10, 12, 48, 55, 56, 57, 58, 59, 63, 67, 68, 70, 71, 72, 80, 82, 98], "float32": [8, 10, 12, 13, 63, 64], "tensor": [8, 10, 12, 13, 46, 48, 49, 51, 55, 56, 57, 58, 59, 66, 67, 68, 70, 71, 72, 73, 80, 96], "per": [8, 10, 12, 13, 17, 36, 49, 70, 71, 72, 80, 82], "loss": [8, 10, 12, 13, 63, 68, 70, 71, 72, 78, 80, 84], "weight": [8, 10, 12, 13, 44, 63, 70, 71, 72, 75, 80, 82, 89, 100], "also": [8, 10, 12, 13, 40, 88], "assign": [8, 10, 12, 13, 38, 46, 49, 98], "float": [8, 10, 12, 13, 44, 46, 55, 62, 63, 67, 75, 76, 78, 80, 81, 88], "note": [8, 10, 12, 13, 76, 89], "valu": [8, 10, 12, 13, 25, 28, 35, 36, 49, 63, 76, 79, 80, 84, 86], "specifi": [8, 10, 12, 13, 40, 46, 70, 71, 72, 76, 78, 100], "case": [8, 10, 12, 13, 17, 44, 49, 70, 71, 72, 100], "That": [8, 10, 12, 13, 56, 71, 79], "ignor": [8, 10, 12, 13, 30], "resolv": [8, 10, 12, 40], "10000": [8, 10, 12, 40], "20": [8, 10, 12, 40, 95], "defin": [8, 10, 12, 40, 44, 49, 60, 61, 62, 63, 65, 86, 88, 89, 91], "represent": [8, 10, 12, 30, 49, 64], "from_config": [8, 67, 87, 88, 89], "concaten": [8, 28, 56], "query_t": [8, 10, 12], "sequential_index": [8, 10, 12], "some": [8, 10, 12, 63], "out": [8, 56, 68, 69, 80, 95, 98, 100], "sequenti": 8, "len": 8, "self": [8, 63, 75, 86, 91], "_may_": 8, "_indic": 8, "entir": [8, 67], "impos": 8, "befor": [8, 56, 70, 71, 72, 78], "scalar": [8, 73, 80], "length": [8, 30, 78], "element": [8, 28, 30, 68, 73, 90], "present": [8, 84, 93, 94], "add_label": 8, "fn": [8, 30, 86, 90], "kei": [8, 17, 28, 29, 30, 35, 36, 46, 49, 79, 88, 89], "add": [8, 56, 84, 98, 100], "custom": [8, 63, 78], "concatdataset": 8, "singl": [8, 15, 49, 56, 79, 88, 89], "collect": [8, 14, 15, 27, 80, 96], "iter": 8, "parquetdataset": [9, 10], "pytorch": [10, 12, 13, 78, 100], "sqlitedataset": [11, 12, 13], "sqlitedatasetperturb": [11, 13], "databas": [12, 13, 33, 35, 36, 38, 75, 82, 100], "perturb": 13, "perturbation_dict": 13, "step": [13, 68, 78], "where": [13, 63, 64, 66, 79], "randomli": [13, 40, 89], "nois": [13, 16, 29, 44], "intend": [13, 100], "test": [13, 40, 70, 71, 72, 81, 88, 94, 98], "stabil": 13, "small": [13, 80], "chang": [13, 75, 80, 98], "dictionari": [13, 28, 29, 30, 33, 35, 63, 75, 76, 86, 88, 89, 91], "deviat": 13, "i3fram": [14, 15, 17, 29, 30, 44], "frame": [14, 15, 17, 27, 30, 35, 44], "i3extractorcollect": [14, 15], "i3featureextractoricecube86": [14, 16], "i3featureextractoricecubedeepcor": [14, 16], "i3featureextractoricecubeupgrad": [14, 16], "i3pulsenoisetruthflagicecubeupgrad": [14, 16], "i3galacticplanehybridrecoextractor": [14, 18], "i3ntmuonlabelextractor": [14, 19], "i3splinempeicextractor": [14, 24], "__call__": 15, "icetrai": [15, 29, 30, 44, 94], "keep": 15, "proven": 15, "set_fil": 15, "refer": [15, 88], "being": [15, 44, 70, 71, 72], "get": [15, 29, 78, 81, 100], "treat": 15, "86": [16, 52], "flag": [16, 44], "exclude_kei": 17, "dynam": [17, 48, 56, 57, 58], "pars": [17, 76, 83, 84, 85, 86, 91], "call": [17, 30, 35, 49, 75, 82, 95], "tri": [17, 30], "automat": [17, 80, 98], "cast": [17, 30], "done": [17, 49, 95, 98], "recurs": [17, 30, 90, 93], "each": [17, 28, 30, 36, 38, 39, 46, 49, 52, 53, 56, 58, 62, 63, 64, 66, 67, 70, 71, 72, 73, 75, 76, 78, 93], "look": [17, 100], "member": [17, 30, 88, 89, 95], "variabl": [17, 30, 56, 73, 82, 95], "signatur": [17, 30], "similar": [17, 30, 100], "handl": [17, 80, 84, 95], "hand": 17, "mc": [17, 35, 36], "tree": [17, 35], "trigger": 17, "exclud": [17, 38, 100], "valueerror": [17, 67], "hybrid": 18, "galatict": 18, "plane": [18, 80], "tum": [19, 26], "dnn": [19, 26], "padding_valu": [19, 22], "northeren": 19, "i3particl": 20, "other": [20, 36, 62, 80, 98], "algorithm": 20, "comparison": [20, 80], "quantiti": [21, 70, 71, 72, 73], "queso": 22, "retro": [23, 33], "splinemp": 24, "border": 25, "mctree": [25, 29], "ndarrai": [25, 63, 82], "arrai": [25, 28], "boundari": 25, "volum": 25, "coordin": [25, 73], "particl": [25, 36, 79], "start": [25, 98, 100], "stop": [25, 84], "within": [25, 46, 48, 49, 56, 62], "hard": 25, "i3mctre": 25, "flatten_nested_dictionari": [27, 28], "serialis": [27, 28], "transpose_list_of_dict": [27, 28], "frame_is_montecarlo": [27, 29], "frame_is_nois": [27, 29], "get_om_keys_and_pulseseri": [27, 29], "is_boost_enum": [27, 30], "is_boost_class": [27, 30], "is_icecube_class": [27, 30], "is_typ": [27, 30], "is_method": [27, 30], "break_cyclic_recurs": [27, 30], "get_member_vari": [27, 30], "cast_object_to_pure_python": [27, 30], "cast_pulse_series_to_pure_python": [27, 30], "manipul": [28, 60, 61, 65], "obj": [28, 30, 90], "parent_kei": 28, "flatten": 28, "nest": 28, "non": [28, 30, 35, 36, 80], "exampl": [28, 40, 46, 49, 80, 88, 89, 100], "d": [28, 63, 66, 98], "b": [28, 46, 49], "c": [28, 49, 80, 100], "2": [28, 49, 56, 58, 62, 64, 71, 73, 75, 76, 80, 88, 100], "a__b": 28, "applic": 28, "combin": [28, 88], "parent": 28, "__": [28, 30], "nester": 28, "json": [28, 88], "therefor": 28, "we": [28, 30, 40, 98, 100], "outer": 28, "abl": [28, 100], "de": 28, "transpos": 28, "mont": 29, "carlo": 29, "simul": [29, 44], "pulseseri": 29, "calibr": [29, 30], "gcd_dict": [29, 30], "p": [29, 35, 80], "om": [29, 30], "dataclass": 29, "i3calibr": 29, "indicesfor": 29, "boost": 30, "enum": 30, "ensur": [30, 39, 80, 95, 98, 100], "isn": 30, "return_discard": 30, "valid": [30, 40, 68, 70, 71, 72, 80, 84, 86, 91], "mangl": 30, "take": [30, 35, 49, 98], "mainli": 30, "cannot": [30, 86, 91], "trivial": [30, 72], "doe": [30, 89], "try": 30, "equival": 30, "its": 30, "like": [30, 49, 73, 80, 96, 98], "otherwis": [30, 80], "itself": [30, 70, 71, 72], "deem": 30, "wai": [30, 40, 98, 100], "optic": 30, "found": [30, 80], "parquetdataconvert": [31, 32], "module_dict": 33, "devic": 33, "retro_table_nam": 33, "n_worker": [33, 75], "pipeline_nam": 33, "creat": [33, 35, 36, 63, 86, 87, 91, 98, 100], "initialis": [33, 89], "gnn_module_for_energy_regress": 33, "modulelist": 33, "comput": [33, 68, 70, 71, 72, 73, 80], "directori": [33, 38, 75, 93], "100": [33, 100], "size": [33, 48, 49, 56, 57, 58, 84], "alreadi": [33, 36, 100], "error": [33, 80, 95, 98], "prompt": 33, "avoid": [33, 95, 98], "overwrit": [33, 78], "sqlitedataconvert": [34, 35, 100], "construct_datafram": [34, 35], "is_pulse_map": [34, 35], "is_mc_tre": [34, 35], "database_exist": [34, 36], "database_table_exist": [34, 36], "run_sql_cod": [34, 36], "save_to_sql": [34, 36], "attach_index": [34, 36], "create_t": [34, 36], "create_table_and_save_to_sql": [34, 36], "db": [35, 81], "max_table_s": 35, "maximum": [35, 49, 70, 71, 72, 84], "row": [35, 36], "exce": 35, "limit": [35, 80], "any_pulsemap_is_non_empti": 35, "data_dict": 35, "empti": [35, 44], "retriev": 35, "splitinicepuls": 35, "least": [35, 98, 100], "true": [35, 36, 44, 75, 78, 80, 82, 88, 89, 91], "becaus": [35, 39], "instead": [35, 80, 86, 91], "alwai": 35, "panda": [35, 40, 82], "datafram": [35, 36, 40, 67, 68, 75, 81, 82], "table_nam": [35, 36], "database_path": [36, 75, 82], "df": 36, "must": [36, 46, 78, 82, 98], "attach": 36, "default_typ": 36, "null": 36, "integer_primary_kei": 36, "NOT": [36, 80], "integ": [36, 56, 57, 80], "primari": 36, "Such": 36, "appropri": [36, 70, 71, 72], "expect": [36, 40, 44, 66], "doesn": 36, "parquettosqliteconvert": [37, 38], "pairwise_shuffl": [37, 39], "stringselectionresolv": [37, 40], "parquet_path": 38, "mc_truth_tabl": 38, "excluded_field": 38, "id": 38, "everi": [38, 100], "field": [38, 76, 79, 86, 88, 89, 91], "One": [38, 76], "choos": 38, "argument": [38, 82, 84, 86, 88, 89, 91], "exclude_field": 38, "database_nam": 38, "convers": [38, 100], "rng": 39, "relat": [39, 93], "i3_list": [39, 93], "gcd_list": [39, 93], "correpond": 39, "handi": 39, "even": 39, "files_list": 39, "gcd_shuffl": 39, "i3_shuffl": 39, "use_cach": 40, "flexibl": 40, "below": [40, 76, 82, 98, 100], "show": [40, 78], "involv": 40, "cover": 40, "yml": [40, 84, 88, 89], "50000": [40, 88], "ab": [40, 80, 88], "12": [40, 88], "14": [40, 88], "16": [40, 88], "13": [40, 100], "compat": 40, "syntax": [40, 80], "mai": [40, 66, 100], "fix": 40, "graphnet_modul": [41, 42], "graphneti3modul": [42, 44], "i3inferencemodul": [42, 44], "i3pulsecleanermodul": [42, 44], "pulsemap_extractor": 44, "produc": [44, 79, 82], "write": [44, 100], "constructor": 44, "knngraph": [44, 60, 64], "associ": [44, 63, 71, 80], "model_config": [44, 83, 85, 86, 88, 91], "state_dict": [44, 67], "model_nam": [44, 75], "prediction_column": [44, 67, 68, 81], "pulsmap": 44, "modelconfig": [44, 67, 85, 88, 89], "summar": 44, "Will": [44, 62], "help": [44, 84, 98], "entri": [44, 56, 76, 84], "dynedg": [44, 45, 54, 57, 58], "energy_reco": 44, "discard_empty_ev": 44, "clean": [44, 98, 100], "assum": [44, 51, 72, 73], "7": [44, 49, 75], "consid": [44, 100], "posit": [44, 49, 71], "signal": 44, "els": 44, "fals": [44, 56, 67, 75, 78, 80, 82, 88], "elimin": 44, "speed": 44, "especi": 44, "sinc": [44, 80], "further": 44, "calcul": [44, 62, 64, 68, 73, 79, 80], "convnet": [45, 54], "dynedge_jinst": [45, 54], "dynedge_kaggle_tito": [45, 54], "edg": [45, 48, 49, 56, 57, 58, 60, 63, 64, 65, 66, 73], "unbatch_edge_index": [45, 46], "attributecoarsen": [45, 46], "domcoarsen": [45, 46], "customdomcoarsen": [45, 46], "domandtimewindowcoarsen": [45, 46], "standardmodel": [45, 68], "calculate_xyzt_homophili": [45, 73], "calculate_distance_matrix": [45, 73], "knn_graph_batch": [45, 73], "oper": [46, 48, 54, 56], "cluster": [46, 48, 49, 56, 58], "local": [46, 84], "edge_index": [46, 48, 73], "vector": [46, 49, 80], "longtensor": [46, 49, 73], "mathbf": [46, 49], "ldot": [46, 49], "n": [46, 49, 80], "reduc": 46, "transfer_attribut": 46, "reduce_opt": 46, "avg": 46, "avg_pool": 46, "avg_pool_x": 46, "max": [46, 48, 56, 58, 80, 84], "max_pool": [46, 49], "max_pool_x": [46, 49], "min": [46, 49, 56, 58], "min_pool": [46, 47, 49], "min_pool_x": [46, 47, 49], "sum": [46, 49, 56, 58, 68], "sum_pool": [46, 47, 49], "sum_pool_x": [46, 47, 49], "forward": [46, 48, 51, 55, 56, 57, 58, 59, 62, 63, 66, 67, 68, 72, 80], "simplecoarsen": 46, "addit": [46, 48, 67, 68, 80, 82], "time_window": 46, "time_kei": 46, "window": 46, "dynedgeconv": [47, 48, 56], "edgeconvtito": [47, 48], "dyntran": [47, 48, 58], "sum_pool_and_distribut": [47, 49], "group_bi": [47, 49], "group_pulses_to_dom": [47, 49], "group_pulses_to_pmt": [47, 49], "std_pool_x": [47, 49], "std_pool": [47, 49], "aggr": 48, "nb_neighbor": 48, "features_subset": [48, 56, 58], "edgeconv": 48, "lightningmodul": [48, 67, 78, 95], "convolut": [48, 55, 56, 57, 58], "mlp": [48, 56], "aggreg": [48, 49], "8": [48, 49, 56, 64, 80, 98, 100], "neighbour": [48, 56, 58, 62, 64, 73], "after": [48, 56, 78, 84], "sequenc": 48, "slice": [48, 56, 58], "sparsetensor": 48, "messagepass": 48, "tito": [48, 58], "solut": [48, 58, 98], "deep": [48, 58], "competit": [48, 52, 58], "reset_paramet": 48, "reset": 48, "learnabl": [48, 54, 55, 56, 57, 58, 59], "messag": [48, 78, 95], "x_i": 48, "x_j": 48, "layer_s": 48, "n_head": 48, "dyntrans1": 48, "head": 48, "multiheadattent": 48, "just": [49, 100], "negat": 49, "cluster_index": 49, "distribut": [49, 56, 71, 80, 82], "ident": [49, 72], "pmt": 49, "f1": 49, "f2": 49, "6": [49, 76], "groupbi": 49, "3": [49, 55, 58, 71, 73, 75, 76, 80, 98, 100], "matrix": [49, 62, 73, 80], "mathbb": 49, "r": [49, 62, 100], "n_1": 49, "n_b": 49, "obtain": [49, 80], "wise": 49, "dens": 49, "fc": 49, "known": 49, "std": 49, "repres": [49, 63, 64, 66, 86, 88, 89], "averag": [49, 80], "torch_geometr": 49, "version": [49, 70, 71, 72, 78, 98, 100], "standardis": 50, "icecubekaggl": [50, 52], "icecubedeepcor": [50, 52], "icecubeupgrad": [50, 52], "ins": 51, "feature_map": [51, 52, 53], "node_featur": [51, 63], "node_feature_nam": [51, 63, 64, 66], "adjac": 51, "dimens": [52, 53, 55, 56, 58, 80], "prototyp": 53, "dynedgejinst": [54, 57], "dynedgetito": [54, 58], "author": [55, 57, 80], "martin": 55, "minh": 55, "nb_input": [55, 56, 57, 58, 59, 70, 71, 72], "nb_output": [55, 57, 59, 66, 70, 72], "nb_intermedi": 55, "dropout_ratio": 55, "128": [55, 56, 84], "fraction": 55, "drop": 55, "nb_neighbour": 56, "dynedge_layer_s": 56, "post_processing_layer_s": 56, "readout_layer_s": 56, "global_pooling_schem": [56, 58], "add_global_variables_after_pool": 56, "k": [56, 58, 62, 64, 73, 80], "nearest": [56, 58, 62, 64, 73], "latent": [56, 58, 70], "metric": [56, 58, 78], "dimenion": [56, 58], "multi": 56, "perceptron": 56, "256": 56, "336": 56, "hidden": [56, 57, 70, 72], "skip": 56, "post": 56, "_and_": 56, "As": 56, "last": [56, 70, 72, 78], "scheme": [56, 58], "altern": [56, 80, 98], "exact": [57, 80], "2209": 57, "03042": 57, "oerso": 57, "layer_size_scal": 57, "4": [57, 58, 71, 76], "scale": [57, 63, 70, 71, 72, 80], "ic": 58, "univers": 58, "south": 58, "pole": 58, "dyntrans_layer_s": 58, "core": 59, "edgedefinit": [60, 61, 62, 63, 65], "how": [60, 61, 65], "drawn": [60, 61, 64, 65], "between": [60, 61, 62, 65, 68, 73, 78, 80], "knnedg": [61, 62], "radialedg": [61, 62], "euclideanedg": [61, 62], "log_fold": [62, 67, 95], "_construct_edg": 62, "nb_nearest_neighbour": [62, 64], "definit": [62, 63, 64, 66, 67, 98], "space": [62, 82], "distanc": [62, 64, 73], "radiu": 62, "sphere": 62, "chosen": [62, 95], "centr": 62, "radial": 62, "center": 62, "sigma": 62, "euclidean": [62, 98], "see": [62, 63, 78, 98, 100], "http": [62, 63, 80, 98], "arxiv": [62, 80], "org": [62, 80, 100], "pdf": 62, "1809": 62, "06166": 62, "hold": 63, "alter": 63, "dure": [63, 70, 71, 72, 78], "node_definit": [63, 64], "edge_definit": 63, "geometri": 63, "nodedefinit": [63, 64, 65, 66], "truth_dict": 63, "custom_label_funct": 63, "loss_weight": [63, 70, 71, 72], "data_path": 63, "shape": [63, 66, 73, 80], "num_nod": 63, "github": [63, 80, 100], "com": [63, 80, 100], "team": [63, 98], "blob": [63, 80], "getting_start": 63, "md": 63, "your": [64, 98, 100], "nodesaspuls": [65, 66], "num_puls": 66, "overridden": 66, "set_number_of_input": 66, "measur": [66, 73], "cherenkov": 66, "radiat": 66, "train_dataload": 67, "val_dataload": 67, "max_epoch": 67, "gpu": [67, 68, 84, 100], "ckpt_path": 67, "log_every_n_step": 67, "gradient_clip_v": 67, "distribution_strategi": [67, 68], "trainer_kwarg": 67, "pytorch_lightn": [67, 95], "trainer": [67, 78, 81], "predict_as_datafram": [67, 68], "additional_attribut": [67, 68, 81], "save_state_dict": 67, "load_state_dict": 67, "karg": 67, "trust": 67, "enough": 67, "eval": [67, 100], "lambda": 67, "consequ": 67, "optimizer_class": 68, "optim": [68, 78], "adam": 68, "optimizer_kwarg": 68, "scheduler_class": 68, "scheduler_kwarg": 68, "scheduler_config": 68, "target_label": [68, 70, 71, 72], "target": [68, 70, 71, 72, 80, 91], "prediction_label": [68, 70, 71, 72], "configure_optim": 68, "shared_step": 68, "batch_idx": 68, "share": 68, "training_step": 68, "train_batch": 68, "validation_step": 68, "val_batch": 68, "compute_loss": [68, 70, 71, 72], "pred": [68, 72], "verbos": [68, 78], "activ": [68, 72, 98, 100], "mode": [68, 72], "deactiv": [68, 72], "multiclassclassificationtask": [69, 70], "binaryclassificationtask": [69, 70], "binaryclassificationtasklogit": [69, 70], "azimuthreconstructionwithkappa": [69, 71], "azimuthreconstruct": [69, 71], "directionreconstructionwithkappa": [69, 71], "zenithreconstruct": [69, 71], "zenithreconstructionwithkappa": [69, 71], "energyreconstruct": [69, 71], "energyreconstructionwithpow": [69, 71], "energyreconstructionwithuncertainti": [69, 71], "vertexreconstruct": [69, 71], "positionreconstruct": [69, 71], "timereconstruct": [69, 71], "inelasticityreconstruct": [69, 71], "identitytask": [69, 70, 72], "arg": [70, 72, 80, 84, 86, 91, 95], "classifi": 70, "untransform": 70, "logit": [70, 80], "affin": [70, 71, 72], "hidden_s": [70, 71, 72], "transform_prediction_and_target": [70, 71, 72], "transform_target": [70, 71, 72], "transform_infer": [70, 71, 72], "transform_support": [70, 71, 72], "binari": [70, 80], "feed": [70, 71, 72], "lossfunct": [70, 71, 72, 77, 80], "auto": [70, 71, 72], "matic": [70, 71, 72], "_pred": [70, 71, 72], "numer": [70, 71, 72], "stabl": [70, 71, 72], "log10": [70, 71, 72, 82], "rather": [70, 71, 72, 95], "conjunct": [70, 71, 72], "invers": [70, 71, 72], "recov": [70, 71, 72], "minimum": [70, 71, 72], "restrict": [70, 71, 72, 80], "invert": [70, 71, 72], "1e6": [70, 71, 72], "default_target_label": [70, 71, 72], "default_prediction_label": [70, 71, 72], "target_pr": 70, "angl": [71, 79], "kappa": [71, 80], "var": 71, "azimuth_pr": 71, "azimuth_kappa": 71, "3d": [71, 80], "vmf": 71, "dir_x_pr": 71, "dir_y_pr": 71, "dir_z_pr": 71, "direction_kappa": 71, "zenith_pr": 71, "zenith_kappa": 71, "energy_pr": 71, "uncertainti": 71, "energy_sigma": 71, "vertex": 71, "position_x_pr": 71, "position_y_pr": 71, "position_z_pr": 71, "interaction_time_pr": 71, "interact": 71, "hadron": 71, "inelasticity_pr": 71, "wrt": 72, "train_ev": 72, "xyzt": 73, "homophili": 73, "notic": [73, 80], "xyz_coord": 73, "pairwis": 73, "nb_dom": 73, "updat": [73, 75, 78], "config_updat": [74, 75], "weightfitt": [74, 75, 77, 82], "contourfitt": [74, 75], "read_entri": [74, 76], "plot_2d_contour": [74, 76], "plot_1d_contour": [74, 76], "contour": [75, 76], "config_path": 75, "new_config_path": 75, "dummy_sect": 75, "temp": 75, "dummi": 75, "section": 75, "header": 75, "configupdat": 75, "programat": 75, "statistical_fit": 75, "fit_weight": [75, 82], "config_outdir": 75, "weight_nam": [75, 82], "pisa_config_dict": 75, "add_to_databas": [75, 82], "flux": 75, "_database_path": 75, "statist": 75, "effect": [75, 78, 98], "account": 75, "systemat": 75, "hypersurfac": 75, "assumpt": 75, "regard": 75, "pipeline_path": 75, "post_fix": 75, "include_retro": 75, "fit_1d_contour": 75, "run_nam": 75, "config_dict": 75, "grid_siz": 75, "theta23_minmax": 75, "36": 75, "54": 75, "dm31_minmax": 75, "1d": [75, 76], "fit_2d_contour": 75, "2d": [75, 76, 80], "content": 76, "contour_data": 76, "xlim": 76, "ylim": 76, "0023799999999999997": 76, "0025499999999999997": 76, "chi2_critical_valu": 76, "width": 76, "height": 76, "path_to_pisa_fit_result": 76, "name_of_my_model_in_fit": 76, "legend": 76, "color": 76, "linestyl": 76, "style": [76, 98], "line": [76, 78, 84], "upper": 76, "axi": 76, "605": 76, "critic": [76, 95], "chi2": 76, "90": 76, "cl": 76, "right": [76, 80], "176": 76, "inch": 76, "388": 76, "706": 76, "abov": [76, 80, 82, 100], "352": 76, "piecewiselinearlr": [77, 78], "progressbar": [77, 78], "mseloss": [77, 80], "rmseloss": [77, 80], "logcoshloss": [77, 80], "crossentropyloss": [77, 80], "binarycrossentropyloss": [77, 80], "logcmk": [77, 80], "vonmisesfisherloss": [77, 80], "vonmisesfisher2dloss": [77, 80], "euclideandistanceloss": [77, 80], "vonmisesfisher3dloss": [77, 80], "make_dataload": [77, 81], "make_train_validation_dataload": [77, 81], "get_predict": [77, 81], "save_result": [77, 81], "uniform": [77, 82], "bjoernlow": [77, 82], "mileston": 78, "factor": 78, "last_epoch": 78, "_lrschedul": 78, "interpol": 78, "linearli": 78, "denot": 78, "multipli": 78, "closest": 78, "vice": 78, "versa": 78, "wrap": [78, 88, 89], "epoch": [78, 84], "print": [78, 95], "stdout": 78, "get_lr": 78, "refresh_r": 78, "process_posit": 78, "tqdmprogressbar": 78, "progress": 78, "bar": 78, "customis": 78, "lightn": 78, "init_validation_tqdm": 78, "overrid": 78, "init_predict_tqdm": 78, "init_test_tqdm": 78, "init_train_tqdm": 78, "get_metr": 78, "on_train_epoch_start": 78, "previou": 78, "behaviour": 78, "on_train_epoch_end": 78, "don": [78, 100], "duplciat": 78, "runtim": [79, 100], "azimuth_kei": 79, "zenith_kei": 79, "access": [79, 100], "azimiuth": 79, "return_el": 80, "elementwis": 80, "term": 80, "squar": 80, "root": [80, 100], "cosh": 80, "act": 80, "cross": 80, "entropi": 80, "num_class": 80, "softmax": 80, "ed": 80, "probabl": 80, "mit": 80, "licens": 80, "copyright": 80, "2019": 80, "ryabinin": 80, "permiss": 80, "herebi": 80, "person": 80, "copi": 80, "document": 80, "deal": 80, "modifi": 80, "publish": 80, "sublicens": 80, "sell": 80, "permit": 80, "whom": 80, "furnish": 80, "so": [80, 100], "subject": 80, "condit": 80, "shall": 80, "substanti": 80, "portion": 80, "THE": 80, "AS": 80, "warranti": 80, "OF": 80, "kind": 80, "OR": 80, "impli": 80, "BUT": 80, "TO": 80, "merchant": 80, "FOR": 80, "particular": [80, 98], "AND": 80, "noninfring": 80, "IN": 80, "NO": 80, "holder": 80, "BE": 80, "liabl": 80, "claim": 80, "damag": 80, "liabil": 80, "action": 80, "contract": 80, "tort": 80, "aris": 80, "WITH": 80, "_____________________": 80, "mryab": 80, "vmf_loss": 80, "master": 80, "py": [80, 100], "bessel": 80, "exponenti": 80, "ditto": 80, "iv": 80, "1812": 80, "04616": 80, "spite": 80, "suggest": 80, "sec": 80, "paper": 80, "m": 80, "correct": 80, "static": [80, 98], "ctx": 80, "backward": 80, "grad_output": 80, "von": 80, "mise": 80, "fisher": 80, "log_cmk_exact": 80, "c_": 80, "exactli": [80, 95], "log_cmk_approx": 80, "approx": 80, "minu": 80, "sign": 80, "log_cmk": 80, "kappa_switch": 80, "diverg": 80, "700": 80, "float64": 80, "precis": 80, "unaccur": 80, "switch": 80, "three": 80, "database_indic": 81, "test_siz": 81, "node_level": 81, "tag": [81, 98, 100], "archiv": 81, "public": 82, "uniformweightfitt": 82, "bin": 82, "privat": 82, "_fit_weight": 82, "sql": 82, "desir": [82, 93], "np": 82, "happen": 82, "x_low": 82, "wherea": 82, "curv": 82, "base_config": [83, 85], "dataset_config": [83, 85], "training_config": [83, 85], "argumentpars": [83, 84], "is_gcd_fil": [83, 93], "is_i3_fil": [83, 93], "has_extens": [83, 93], "find_i3_fil": [83, 93], "has_icecube_packag": [83, 94], "has_torch_packag": [83, 94], "has_pisa_packag": [83, 94], "requires_icecub": [83, 94], "repeatfilt": [83, 95], "eps_lik": [83, 96], "consist": [84, 95, 98], "cli": 84, "pop_default": 84, "usag": 84, "descript": 84, "command": [84, 100], "standard_argu": 84, "home": [84, 100], "runner": 84, "lib": [84, 100], "python3": 84, "training_example_data_sqlit": 84, "earli": 84, "patienc": 84, "narg": 84, "50": 84, "example_energy_reconstruction_model": 84, "num": 84, "fetch": 84, "with_standard_argu": 84, "overwritten": [84, 86], "baseconfig": [85, 86, 87, 88, 89, 91], "get_all_argument_valu": [85, 86], "save_dataset_config": [85, 88], "save_model_config": [85, 89], "traverse_and_appli": [85, 90], "list_all_submodul": [85, 90], "get_all_grapnet_class": [85, 90], "is_graphnet_modul": [85, 90], "is_graphnet_class": [85, 90], "get_graphnet_class": [85, 90], "trainingconfig": [85, 91], "basemodel": [86, 88, 89], "keyword": [86, 91], "validationerror": [86, 91], "pydantic_cor": [86, 91], "__init__": [86, 88, 89, 91, 100], "__pydantic_self__": [86, 91], "dump": [86, 88, 89], "yaml": [86, 87], "as_dict": [86, 88, 89], "classvar": [86, 88, 89, 91], "configdict": [86, 88, 89, 91], "conform": [86, 88, 89, 91], "pydant": [86, 88, 89, 91], "model_field": [86, 88, 89, 91], "fieldinfo": [86, 88, 89, 91], "metadata": [86, 88, 89, 91], "about": [86, 88, 89, 91], "__fields__": [86, 88, 89, 91], "v1": [86, 88, 89, 91, 100], "re": [87, 100], "save_config": 87, "dataconfig": 88, "transpar": [88, 89, 98], "reproduc": [88, 89], "In": [88, 89, 100], "session": [88, 89], "anoth": [88, 89], "you": [88, 89, 98, 100], "still": 88, "csv": 88, "train_select": 88, "test_select": 88, "unambigu": [88, 89], "annot": [88, 89, 91], "nonetyp": 88, "init_fn": [88, 89], "trainabl": 89, "hyperparamet": 89, "instanti": 89, "thu": 89, "fn_kwarg": 90, "structur": 90, "moduletyp": 90, "grapnet": 90, "lookup": 90, "early_stopping_pati": 91, "system": [93, 100], "filenam": 93, "dir": 93, "search": 93, "test_funct": 94, "filter": 95, "repeat": 95, "nb_repeats_allow": 95, "record": 95, "logrecord": 95, "clear": 95, "intuit": 95, "composit": 95, "loggeradapt": 95, "clash": 95, "setlevel": 95, "deleg": 95, "msg": 95, "warn": 95, "info": [95, 100], "debug": 95, "warning_onc": 95, "onc": 95, "handler": 95, "file_handl": 95, "filehandl": 95, "stream_handl": 95, "streamhandl": 95, "assort": 96, "ep": 96, "api": 97, "To": [98, 100], "sure": [98, 100], "smooth": 98, "guidelin": 98, "guid": 98, "encourag": 98, "contributor": 98, "discuss": 98, "bug": 98, "anyth": 98, "place": 98, "describ": 98, "yourself": 98, "ownership": 98, "prioriti": 98, "situat": 98, "lot": 98, "effort": 98, "go": 98, "turn": 98, "outsid": 98, "scope": 98, "better": 98, "fork": 98, "repo": 98, "dedic": 98, "branch": [98, 100], "repositori": 98, "own": [98, 100], "accept": 98, "autom": 98, "review": 98, "pep8": 98, "docstr": 98, "googl": 98, "hint": 98, "adher": 98, "pep": 98, "pylint": 98, "flake8": 98, "black": 98, "well": 98, "recommend": [98, 100], "mypi": 98, "pydocstyl": 98, "docformatt": 98, "commit": 98, "hook": 98, "instal": 98, "come": 98, "pip": [98, 100], "Then": 98, "everytim": 98, "pep257": 98, "concept": 98, "ljvmiranda921": 98, "io": 98, "notebook": 98, "2018": 98, "06": 98, "21": 98, "precommit": 98, "environ": 100, "virtual": 100, "anaconda": 100, "prove": 100, "instruct": 100, "setup": 100, "want": 100, "part": 100, "achiev": 100, "bash": 100, "shell": 100, "cvmf": 100, "opensciencegrid": 100, "py3": 100, "v4": 100, "sh": 100, "rhel_7_x86_64": 100, "metaproject": 100, "env": 100, "alia": 100, "script": 100, "With": 100, "now": 100, "light": 100, "extra": 100, "geometr": 100, "won": 100, "later": 100, "torch_cpu": 100, "txt": 100, "cpu": 100, "torch_gpu": 100, "prefer": 100, "unix": 100, "git": 100, "clone": 100, "usernam": 100, "cd": 100, "conda": 100, "gcc_linux": 100, "64": 100, "gxx_linux": 100, "libgcc": 100, "cudatoolkit": 100, "11": 100, "forg": 100, "torch_maco": 100, "On": 100, "maco": 100, "box": 100, "compil": 100, "gcc": 100, "date": 100, "possibli": 100, "cuda": 100, "toolkit": 100, "recent": 100, "omit": 100, "newer": 100, "export": 100, "ld_library_path": 100, "anaconda3": 100, "miniconda3": 100, "bashrc": 100, "librari": 100, "rm": 100, "asogaard": 100, "latest": 100, "dc423315742c": 100, "01_icetrai": 100, "01_convert_i3_fil": 100, "2023": 100, "01": 100, "24": 100, "41": 100, "27": 100, "graphnet_20230124": 100, "134127": 100, "46": 100, "convert_i3_fil": 100, "ic86": 100, "thread": 100, "00": 100, "79": 100, "42": 100, "26": 100, "413": 100, "88it": 100, "specialis": 100, "ones": 100, "push": 100, "vx": 100}, "objects": {"": [[1, 0, 0, "-", "graphnet"]], "graphnet": [[2, 0, 0, "-", "constants"], [3, 0, 0, "-", "data"], [41, 0, 0, "-", "deployment"], [45, 0, 0, "-", "models"], [74, 0, 0, "-", "pisa"], [77, 0, 0, "-", "training"], [83, 0, 0, "-", "utilities"]], "graphnet.data": [[4, 0, 0, "-", "constants"], [5, 0, 0, "-", "dataconverter"], [6, 0, 0, "-", "dataloader"], [7, 0, 0, "-", "dataset"], [14, 0, 0, "-", "extractors"], [31, 0, 0, "-", "parquet"], [33, 0, 0, "-", "pipeline"], [34, 0, 0, "-", "sqlite"], [37, 0, 0, "-", "utilities"]], "graphnet.data.constants": [[4, 1, 1, "", "FEATURES"], [4, 1, 1, "", "TRUTH"]], "graphnet.data.constants.FEATURES": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.constants.TRUTH": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.dataconverter": [[5, 1, 1, "", "DataConverter"], [5, 1, 1, "", "FileSet"], [5, 5, 1, "", "cache_output_files"], [5, 5, 1, "", "init_global_index"]], "graphnet.data.dataconverter.DataConverter": [[5, 3, 1, "", "execute"], [5, 4, 1, "", "file_suffix"], [5, 3, 1, "", "get_map_function"], [5, 3, 1, "", "merge_files"], [5, 3, 1, "", "save_data"]], "graphnet.data.dataconverter.FileSet": [[5, 2, 1, "", "gcd_file"], [5, 2, 1, "", "i3_file"]], "graphnet.data.dataloader": [[6, 1, 1, "", "DataLoader"], [6, 5, 1, "", "collate_fn"], [6, 5, 1, "", "do_shuffle"]], "graphnet.data.dataloader.DataLoader": [[6, 3, 1, "", "from_dataset_config"]], "graphnet.data.dataset": [[8, 0, 0, "-", "dataset"], [9, 0, 0, "-", "parquet"], [11, 0, 0, "-", "sqlite"]], "graphnet.data.dataset.dataset": [[8, 6, 1, "", "ColumnMissingException"], [8, 1, 1, "", "Dataset"], [8, 1, 1, "", "EnsembleDataset"], [8, 5, 1, "", "load_module"], [8, 5, 1, "", "parse_graph_definition"]], "graphnet.data.dataset.dataset.Dataset": [[8, 3, 1, "", "add_label"], [8, 3, 1, "", "concatenate"], [8, 3, 1, "", "from_config"], [8, 4, 1, "", "path"], [8, 3, 1, "", "query_table"], [8, 4, 1, "", "truth_table"]], "graphnet.data.dataset.parquet": [[10, 0, 0, "-", "parquet_dataset"]], "graphnet.data.dataset.parquet.parquet_dataset": [[10, 1, 1, "", "ParquetDataset"]], "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset": [[10, 3, 1, "", "query_table"]], "graphnet.data.dataset.sqlite": [[12, 0, 0, "-", "sqlite_dataset"], [13, 0, 0, "-", "sqlite_dataset_perturbed"]], "graphnet.data.dataset.sqlite.sqlite_dataset": [[12, 1, 1, "", "SQLiteDataset"]], "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset": [[12, 3, 1, "", "query_table"]], "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed": [[13, 1, 1, "", "SQLiteDatasetPerturbed"]], "graphnet.data.extractors": [[15, 0, 0, "-", "i3extractor"], [16, 0, 0, "-", "i3featureextractor"], [17, 0, 0, "-", "i3genericextractor"], [18, 0, 0, "-", "i3hybridrecoextractor"], [19, 0, 0, "-", "i3ntmuonlabelsextractor"], [20, 0, 0, "-", "i3particleextractor"], [21, 0, 0, "-", "i3pisaextractor"], [22, 0, 0, "-", "i3quesoextractor"], [23, 0, 0, "-", "i3retroextractor"], [24, 0, 0, "-", "i3splinempeextractor"], [25, 0, 0, "-", "i3truthextractor"], [26, 0, 0, "-", "i3tumextractor"], [27, 0, 0, "-", "utilities"]], "graphnet.data.extractors.i3extractor": [[15, 1, 1, "", "I3Extractor"], [15, 1, 1, "", "I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor.I3Extractor": [[15, 4, 1, "", "name"], [15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3extractor.I3ExtractorCollection": [[15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3featureextractor": [[16, 1, 1, "", "I3FeatureExtractor"], [16, 1, 1, "", "I3FeatureExtractorIceCube86"], [16, 1, 1, "", "I3FeatureExtractorIceCubeDeepCore"], [16, 1, 1, "", "I3FeatureExtractorIceCubeUpgrade"], [16, 1, 1, "", "I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3genericextractor": [[17, 1, 1, "", "I3GenericExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, 1, 1, "", "I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, 1, 1, "", "I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, 1, 1, "", "I3ParticleExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, 1, 1, "", "I3PISAExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, 1, 1, "", "I3QUESOExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, 1, 1, "", "I3RetroExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, 1, 1, "", "I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, 1, 1, "", "I3TruthExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, 1, 1, "", "I3TUMExtractor"]], "graphnet.data.extractors.utilities": [[28, 0, 0, "-", "collections"], [29, 0, 0, "-", "frames"], [30, 0, 0, "-", "types"]], "graphnet.data.extractors.utilities.collections": [[28, 5, 1, "", "flatten_nested_dictionary"], [28, 5, 1, "", "serialise"], [28, 5, 1, "", "transpose_list_of_dicts"]], "graphnet.data.extractors.utilities.frames": [[29, 5, 1, "", "frame_is_montecarlo"], [29, 5, 1, "", "frame_is_noise"], [29, 5, 1, "", "get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.types": [[30, 5, 1, "", "break_cyclic_recursion"], [30, 5, 1, "", "cast_object_to_pure_python"], [30, 5, 1, "", "cast_pulse_series_to_pure_python"], [30, 5, 1, "", "get_member_variables"], [30, 5, 1, "", "is_boost_class"], [30, 5, 1, "", "is_boost_enum"], [30, 5, 1, "", "is_icecube_class"], [30, 5, 1, "", "is_method"], [30, 5, 1, "", "is_type"]], "graphnet.data.parquet": [[32, 0, 0, "-", "parquet_dataconverter"]], "graphnet.data.parquet.parquet_dataconverter": [[32, 1, 1, "", "ParquetDataConverter"]], "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter": [[32, 2, 1, "", "file_suffix"], [32, 3, 1, "", "merge_files"], [32, 3, 1, "", "save_data"]], "graphnet.data.pipeline": [[33, 1, 1, "", "InSQLitePipeline"]], "graphnet.data.sqlite": [[35, 0, 0, "-", "sqlite_dataconverter"], [36, 0, 0, "-", "sqlite_utilities"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, 1, 1, "", "SQLiteDataConverter"], [35, 5, 1, "", "construct_dataframe"], [35, 5, 1, "", "is_mc_tree"], [35, 5, 1, "", "is_pulse_map"]], "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter": [[35, 3, 1, "", "any_pulsemap_is_non_empty"], [35, 2, 1, "", "file_suffix"], [35, 3, 1, "", "merge_files"], [35, 3, 1, "", "save_data"]], "graphnet.data.sqlite.sqlite_utilities": [[36, 5, 1, "", "attach_index"], [36, 5, 1, "", "create_table"], [36, 5, 1, "", "create_table_and_save_to_sql"], [36, 5, 1, "", "database_exists"], [36, 5, 1, "", "database_table_exists"], [36, 5, 1, "", "run_sql_code"], [36, 5, 1, "", "save_to_sql"]], "graphnet.data.utilities": [[38, 0, 0, "-", "parquet_to_sqlite"], [39, 0, 0, "-", "random"], [40, 0, 0, "-", "string_selection_resolver"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, 1, 1, "", "ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter": [[38, 3, 1, "", "run"]], "graphnet.data.utilities.random": [[39, 5, 1, "", "pairwise_shuffle"]], "graphnet.data.utilities.string_selection_resolver": [[40, 1, 1, "", "StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver": [[40, 3, 1, "", "resolve"]], "graphnet.deployment.i3modules": [[44, 0, 0, "-", "graphnet_module"]], "graphnet.deployment.i3modules.graphnet_module": [[44, 1, 1, "", "GraphNeTI3Module"], [44, 1, 1, "", "I3InferenceModule"], [44, 1, 1, "", "I3PulseCleanerModule"]], "graphnet.models": [[46, 0, 0, "-", "coarsening"], [47, 0, 0, "-", "components"], [50, 0, 0, "-", "detector"], [54, 0, 0, "-", "gnn"], [60, 0, 0, "-", "graphs"], [67, 0, 0, "-", "model"], [68, 0, 0, "-", "standard_model"], [69, 0, 0, "-", "task"], [73, 0, 0, "-", "utils"]], "graphnet.models.coarsening": [[46, 1, 1, "", "AttributeCoarsening"], [46, 1, 1, "", "Coarsening"], [46, 1, 1, "", "CustomDOMCoarsening"], [46, 1, 1, "", "DOMAndTimeWindowCoarsening"], [46, 1, 1, "", "DOMCoarsening"], [46, 5, 1, "", "unbatch_edge_index"]], "graphnet.models.coarsening.Coarsening": [[46, 3, 1, "", "forward"], [46, 2, 1, "", "reduce_options"]], "graphnet.models.components": [[48, 0, 0, "-", "layers"], [49, 0, 0, "-", "pool"]], "graphnet.models.components.layers": [[48, 1, 1, "", "DynEdgeConv"], [48, 1, 1, "", "DynTrans"], [48, 1, 1, "", "EdgeConvTito"]], "graphnet.models.components.layers.DynEdgeConv": [[48, 3, 1, "", "forward"]], "graphnet.models.components.layers.DynTrans": [[48, 3, 1, "", "forward"]], "graphnet.models.components.layers.EdgeConvTito": [[48, 3, 1, "", "forward"], [48, 3, 1, "", "message"], [48, 3, 1, "", "reset_parameters"]], "graphnet.models.components.pool": [[49, 5, 1, "", "group_by"], [49, 5, 1, "", "group_pulses_to_dom"], [49, 5, 1, "", "group_pulses_to_pmt"], [49, 5, 1, "", "min_pool"], [49, 5, 1, "", "min_pool_x"], [49, 5, 1, "", "std_pool"], [49, 5, 1, "", "std_pool_x"], [49, 5, 1, "", "sum_pool"], [49, 5, 1, "", "sum_pool_and_distribute"], [49, 5, 1, "", "sum_pool_x"]], "graphnet.models.detector": [[51, 0, 0, "-", "detector"], [52, 0, 0, "-", "icecube"], [53, 0, 0, "-", "prometheus"]], "graphnet.models.detector.detector": [[51, 1, 1, "", "Detector"]], "graphnet.models.detector.detector.Detector": [[51, 3, 1, "", "feature_map"], [51, 3, 1, "", "forward"]], "graphnet.models.detector.icecube": [[52, 1, 1, "", "IceCube86"], [52, 1, 1, "", "IceCubeDeepCore"], [52, 1, 1, "", "IceCubeKaggle"], [52, 1, 1, "", "IceCubeUpgrade"]], "graphnet.models.detector.icecube.IceCube86": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeDeepCore": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeKaggle": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeUpgrade": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.prometheus": [[53, 1, 1, "", "Prometheus"]], "graphnet.models.detector.prometheus.Prometheus": [[53, 3, 1, "", "feature_map"]], "graphnet.models.gnn": [[55, 0, 0, "-", "convnet"], [56, 0, 0, "-", "dynedge"], [57, 0, 0, "-", "dynedge_jinst"], [58, 0, 0, "-", "dynedge_kaggle_tito"], [59, 0, 0, "-", "gnn"]], "graphnet.models.gnn.convnet": [[55, 1, 1, "", "ConvNet"]], "graphnet.models.gnn.convnet.ConvNet": [[55, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge": [[56, 1, 1, "", "DynEdge"]], "graphnet.models.gnn.dynedge.DynEdge": [[56, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge_jinst": [[57, 1, 1, "", "DynEdgeJINST"]], "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST": [[57, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge_kaggle_tito": [[58, 1, 1, "", "DynEdgeTITO"]], "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO": [[58, 3, 1, "", "forward"]], "graphnet.models.gnn.gnn": [[59, 1, 1, "", "GNN"]], "graphnet.models.gnn.gnn.GNN": [[59, 3, 1, "", "forward"], [59, 4, 1, "", "nb_inputs"], [59, 4, 1, "", "nb_outputs"]], "graphnet.models.graphs": [[61, 0, 0, "-", "edges"], [63, 0, 0, "-", "graph_definition"], [64, 0, 0, "-", "graphs"], [65, 0, 0, "-", "nodes"]], "graphnet.models.graphs.edges": [[62, 0, 0, "-", "edges"]], "graphnet.models.graphs.edges.edges": [[62, 1, 1, "", "EdgeDefinition"], [62, 1, 1, "", "EuclideanEdges"], [62, 1, 1, "", "KNNEdges"], [62, 1, 1, "", "RadialEdges"]], "graphnet.models.graphs.edges.edges.EdgeDefinition": [[62, 3, 1, "", "forward"]], "graphnet.models.graphs.graph_definition": [[63, 1, 1, "", "GraphDefinition"]], "graphnet.models.graphs.graph_definition.GraphDefinition": [[63, 3, 1, "", "forward"]], "graphnet.models.graphs.graphs": [[64, 1, 1, "", "KNNGraph"]], "graphnet.models.graphs.nodes": [[66, 0, 0, "-", "nodes"]], "graphnet.models.graphs.nodes.nodes": [[66, 1, 1, "", "NodeDefinition"], [66, 1, 1, "", "NodesAsPulses"]], "graphnet.models.graphs.nodes.nodes.NodeDefinition": [[66, 3, 1, "", "forward"], [66, 4, 1, "", "nb_outputs"], [66, 3, 1, "", "set_number_of_inputs"]], "graphnet.models.model": [[67, 1, 1, "", "Model"]], "graphnet.models.model.Model": [[67, 3, 1, "", "fit"], [67, 3, 1, "", "forward"], [67, 3, 1, "", "from_config"], [67, 3, 1, "", "load"], [67, 3, 1, "", "load_state_dict"], [67, 3, 1, "", "predict"], [67, 3, 1, "", "predict_as_dataframe"], [67, 3, 1, "", "save"], [67, 3, 1, "", "save_state_dict"]], "graphnet.models.standard_model": [[68, 1, 1, "", "StandardModel"]], "graphnet.models.standard_model.StandardModel": [[68, 3, 1, "", "compute_loss"], [68, 3, 1, "", "configure_optimizers"], [68, 3, 1, "", "forward"], [68, 3, 1, "", "inference"], [68, 3, 1, "", "predict"], [68, 3, 1, "", "predict_as_dataframe"], [68, 4, 1, "", "prediction_labels"], [68, 3, 1, "", "shared_step"], [68, 4, 1, "", "target_labels"], [68, 3, 1, "", "train"], [68, 3, 1, "", "training_step"], [68, 3, 1, "", "validation_step"]], "graphnet.models.task": [[70, 0, 0, "-", "classification"], [71, 0, 0, "-", "reconstruction"], [72, 0, 0, "-", "task"]], "graphnet.models.task.classification": [[70, 1, 1, "", "BinaryClassificationTask"], [70, 1, 1, "", "BinaryClassificationTaskLogits"], [70, 1, 1, "", "MulticlassClassificationTask"]], "graphnet.models.task.classification.BinaryClassificationTask": [[70, 2, 1, "", "default_prediction_labels"], [70, 2, 1, "", "default_target_labels"], [70, 2, 1, "", "nb_inputs"]], "graphnet.models.task.classification.BinaryClassificationTaskLogits": [[70, 2, 1, "", "default_prediction_labels"], [70, 2, 1, "", "default_target_labels"], [70, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction": [[71, 1, 1, "", "AzimuthReconstruction"], [71, 1, 1, "", "AzimuthReconstructionWithKappa"], [71, 1, 1, "", "DirectionReconstructionWithKappa"], [71, 1, 1, "", "EnergyReconstruction"], [71, 1, 1, "", "EnergyReconstructionWithPower"], [71, 1, 1, "", "EnergyReconstructionWithUncertainty"], [71, 1, 1, "", "InelasticityReconstruction"], [71, 1, 1, "", "PositionReconstruction"], [71, 1, 1, "", "TimeReconstruction"], [71, 1, 1, "", "VertexReconstruction"], [71, 1, 1, "", "ZenithReconstruction"], [71, 1, 1, "", "ZenithReconstructionWithKappa"]], "graphnet.models.task.reconstruction.AzimuthReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstructionWithPower": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.InelasticityReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.PositionReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.TimeReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.VertexReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.ZenithReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.task": [[72, 1, 1, "", "IdentityTask"], [72, 1, 1, "", "Task"]], "graphnet.models.task.task.IdentityTask": [[72, 4, 1, "", "default_prediction_labels"], [72, 4, 1, "", "default_target_labels"], [72, 4, 1, "", "nb_inputs"]], "graphnet.models.task.task.Task": [[72, 3, 1, "", "compute_loss"], [72, 4, 1, "", "default_prediction_labels"], [72, 4, 1, "", "default_target_labels"], [72, 3, 1, "", "forward"], [72, 3, 1, "", "inference"], [72, 4, 1, "", "nb_inputs"], [72, 3, 1, "", "train_eval"]], "graphnet.models.utils": [[73, 5, 1, "", "calculate_distance_matrix"], [73, 5, 1, "", "calculate_xyzt_homophily"], [73, 5, 1, "", "knn_graph_batch"]], "graphnet.pisa": [[75, 0, 0, "-", "fitting"], [76, 0, 0, "-", "plotting"]], "graphnet.pisa.fitting": [[75, 1, 1, "", "ContourFitter"], [75, 1, 1, "", "WeightFitter"], [75, 5, 1, "", "config_updater"]], "graphnet.pisa.fitting.ContourFitter": [[75, 3, 1, "", "fit_1d_contour"], [75, 3, 1, "", "fit_2d_contour"]], "graphnet.pisa.fitting.WeightFitter": [[75, 3, 1, "", "fit_weights"]], "graphnet.pisa.plotting": [[76, 5, 1, "", "plot_1D_contour"], [76, 5, 1, "", "plot_2D_contour"], [76, 5, 1, "", "read_entry"]], "graphnet.training": [[78, 0, 0, "-", "callbacks"], [79, 0, 0, "-", "labels"], [80, 0, 0, "-", "loss_functions"], [81, 0, 0, "-", "utils"], [82, 0, 0, "-", "weight_fitting"]], "graphnet.training.callbacks": [[78, 1, 1, "", "PiecewiseLinearLR"], [78, 1, 1, "", "ProgressBar"]], "graphnet.training.callbacks.PiecewiseLinearLR": [[78, 3, 1, "", "get_lr"]], "graphnet.training.callbacks.ProgressBar": [[78, 3, 1, "", "get_metrics"], [78, 3, 1, "", "init_predict_tqdm"], [78, 3, 1, "", "init_test_tqdm"], [78, 3, 1, "", "init_train_tqdm"], [78, 3, 1, "", "init_validation_tqdm"], [78, 3, 1, "", "on_train_epoch_end"], [78, 3, 1, "", "on_train_epoch_start"]], "graphnet.training.labels": [[79, 1, 1, "", "Direction"], [79, 1, 1, "", "Label"]], "graphnet.training.labels.Label": [[79, 4, 1, "", "key"]], "graphnet.training.loss_functions": [[80, 1, 1, "", "BinaryCrossEntropyLoss"], [80, 1, 1, "", "CrossEntropyLoss"], [80, 1, 1, "", "EuclideanDistanceLoss"], [80, 1, 1, "", "LogCMK"], [80, 1, 1, "", "LogCoshLoss"], [80, 1, 1, "", "LossFunction"], [80, 1, 1, "", "MSELoss"], [80, 1, 1, "", "RMSELoss"], [80, 1, 1, "", "VonMisesFisher2DLoss"], [80, 1, 1, "", "VonMisesFisher3DLoss"], [80, 1, 1, "", "VonMisesFisherLoss"]], "graphnet.training.loss_functions.LogCMK": [[80, 3, 1, "", "backward"], [80, 3, 1, "", "forward"]], "graphnet.training.loss_functions.LossFunction": [[80, 3, 1, "", "forward"]], "graphnet.training.loss_functions.VonMisesFisherLoss": [[80, 3, 1, "", "log_cmk"], [80, 3, 1, "", "log_cmk_approx"], [80, 3, 1, "", "log_cmk_exact"]], "graphnet.training.utils": [[81, 5, 1, "", "collate_fn"], [81, 5, 1, "", "get_predictions"], [81, 5, 1, "", "make_dataloader"], [81, 5, 1, "", "make_train_validation_dataloader"], [81, 5, 1, "", "save_results"]], "graphnet.training.weight_fitting": [[82, 1, 1, "", "BjoernLow"], [82, 1, 1, "", "Uniform"], [82, 1, 1, "", "WeightFitter"]], "graphnet.training.weight_fitting.WeightFitter": [[82, 3, 1, "", "fit"]], "graphnet.utilities": [[84, 0, 0, "-", "argparse"], [85, 0, 0, "-", "config"], [92, 0, 0, "-", "decorators"], [93, 0, 0, "-", "filesys"], [94, 0, 0, "-", "imports"], [95, 0, 0, "-", "logging"], [96, 0, 0, "-", "maths"]], "graphnet.utilities.argparse": [[84, 1, 1, "", "ArgumentParser"], [84, 1, 1, "", "Options"]], "graphnet.utilities.argparse.ArgumentParser": [[84, 2, 1, "", "standard_arguments"], [84, 3, 1, "", "with_standard_arguments"]], "graphnet.utilities.argparse.Options": [[84, 3, 1, "", "contains"], [84, 3, 1, "", "pop_default"]], "graphnet.utilities.config": [[86, 0, 0, "-", "base_config"], [87, 0, 0, "-", "configurable"], [88, 0, 0, "-", "dataset_config"], [89, 0, 0, "-", "model_config"], [90, 0, 0, "-", "parsing"], [91, 0, 0, "-", "training_config"]], "graphnet.utilities.config.base_config": [[86, 1, 1, "", "BaseConfig"], [86, 5, 1, "", "get_all_argument_values"]], "graphnet.utilities.config.base_config.BaseConfig": [[86, 3, 1, "", "as_dict"], [86, 3, 1, "", "dump"], [86, 3, 1, "", "load"], [86, 2, 1, "", "model_config"], [86, 2, 1, "", "model_fields"]], "graphnet.utilities.config.configurable": [[87, 1, 1, "", "Configurable"]], "graphnet.utilities.config.configurable.Configurable": [[87, 4, 1, "", "config"], [87, 3, 1, "", "from_config"], [87, 3, 1, "", "save_config"]], "graphnet.utilities.config.dataset_config": [[88, 1, 1, "", "DatasetConfig"], [88, 5, 1, "", "save_dataset_config"]], "graphnet.utilities.config.dataset_config.DatasetConfig": [[88, 3, 1, "", "as_dict"], [88, 2, 1, "", "features"], [88, 2, 1, "", "graph_definition"], [88, 2, 1, "", "index_column"], [88, 2, 1, "", "loss_weight_column"], [88, 2, 1, "", "loss_weight_default_value"], [88, 2, 1, "", "loss_weight_table"], [88, 2, 1, "", "model_config"], [88, 2, 1, "", "model_fields"], [88, 2, 1, "", "node_truth"], [88, 2, 1, "", "node_truth_table"], [88, 2, 1, "", "path"], [88, 2, 1, "", "pulsemaps"], [88, 2, 1, "", "seed"], [88, 2, 1, "", "selection"], [88, 2, 1, "", "string_selection"], [88, 2, 1, "", "truth"], [88, 2, 1, "", "truth_table"]], "graphnet.utilities.config.model_config": [[89, 1, 1, "", "ModelConfig"], [89, 5, 1, "", "save_model_config"]], "graphnet.utilities.config.model_config.ModelConfig": [[89, 2, 1, "", "arguments"], [89, 3, 1, "", "as_dict"], [89, 2, 1, "", "class_name"], [89, 2, 1, "", "model_config"], [89, 2, 1, "", "model_fields"]], "graphnet.utilities.config.parsing": [[90, 5, 1, "", "get_all_grapnet_classes"], [90, 5, 1, "", "get_graphnet_classes"], [90, 5, 1, "", "is_graphnet_class"], [90, 5, 1, "", "is_graphnet_module"], [90, 5, 1, "", "list_all_submodules"], [90, 5, 1, "", "traverse_and_apply"]], "graphnet.utilities.config.training_config": [[91, 1, 1, "", "TrainingConfig"]], "graphnet.utilities.config.training_config.TrainingConfig": [[91, 2, 1, "", "dataloader"], [91, 2, 1, "", "early_stopping_patience"], [91, 2, 1, "", "fit"], [91, 2, 1, "", "model_config"], [91, 2, 1, "", "model_fields"], [91, 2, 1, "", "target"]], "graphnet.utilities.filesys": [[93, 5, 1, "", "find_i3_files"], [93, 5, 1, "", "has_extension"], [93, 5, 1, "", "is_gcd_file"], [93, 5, 1, "", "is_i3_file"]], "graphnet.utilities.imports": [[94, 5, 1, "", "has_icecube_package"], [94, 5, 1, "", "has_pisa_package"], [94, 5, 1, "", "has_torch_package"], [94, 5, 1, "", "requires_icecube"]], "graphnet.utilities.logging": [[95, 1, 1, "", "Logger"], [95, 1, 1, "", "RepeatFilter"]], "graphnet.utilities.logging.Logger": [[95, 3, 1, "", "critical"], [95, 3, 1, "", "debug"], [95, 3, 1, "", "error"], [95, 4, 1, "", "file_handlers"], [95, 4, 1, "", "handlers"], [95, 3, 1, "", "info"], [95, 3, 1, "", "setLevel"], [95, 4, 1, "", "stream_handlers"], [95, 3, 1, "", "warning"], [95, 3, 1, "", "warning_once"]], "graphnet.utilities.logging.RepeatFilter": [[95, 3, 1, "", "filter"], [95, 2, 1, "", "nb_repeats_allowed"]], "graphnet.utilities.maths": [[96, 5, 1, "", "eps_like"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:attribute", "3": "py:method", "4": "py:property", "5": "py:function", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "attribute", "Python attribute"], "3": ["py", "method", "Python method"], "4": ["py", "property", "Python property"], "5": ["py", "function", "Python function"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"about": [0, 99], "impact": [0, 99], "usag": [0, 99], "acknowledg": [0, 99], "api": 1, "constant": [2, 4], "data": 3, "dataconvert": 5, "dataload": 6, "dataset": [7, 8], "parquet": [9, 31], "parquet_dataset": 10, "sqlite": [11, 34], "sqlite_dataset": 12, "sqlite_dataset_perturb": 13, "extractor": 14, "i3extractor": 15, "i3featureextractor": 16, "i3genericextractor": 17, "i3hybridrecoextractor": 18, "i3ntmuonlabelsextractor": 19, "i3particleextractor": 20, "i3pisaextractor": 21, "i3quesoextractor": 22, "i3retroextractor": 23, "i3splinempeextractor": 24, "i3truthextractor": 25, "i3tumextractor": 26, "util": [27, 37, 73, 81, 83], "collect": 28, "frame": 29, "type": 30, "parquet_dataconvert": 32, "pipelin": 33, "sqlite_dataconvert": 35, "sqlite_util": 36, "parquet_to_sqlit": 38, "random": 39, "string_selection_resolv": 40, "deploy": [41, 43], "i3modul": 42, "graphnet_modul": 44, "model": [45, 67], "coarsen": 46, "compon": 47, "layer": 48, "pool": 49, "detector": [50, 51], "icecub": 52, "prometheu": 53, "gnn": [54, 59], "convnet": 55, "dynedg": 56, "dynedge_jinst": 57, "dynedge_kaggle_tito": 58, "graph": [60, 64], "edg": [61, 62], "graph_definit": 63, "node": [65, 66], "standard_model": 68, "task": [69, 72], "classif": 70, "reconstruct": 71, "pisa": 74, "fit": 75, "plot": 76, "train": 77, "callback": 78, "label": 79, "loss_funct": 80, "weight_fit": 82, "argpars": 84, "config": 85, "base_config": 86, "configur": 87, "dataset_config": 88, "model_config": 89, "pars": 90, "training_config": 91, "decor": 92, "filesi": 93, "import": 94, "log": 95, "math": 96, "src": 97, "contribut": 98, "github": 98, "issu": 98, "pull": 98, "request": 98, "convent": 98, "code": 98, "qualiti": 98, "instal": 100, "icetrai": 100, "stand": 100, "alon": 100, "run": 100, "docker": 100}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 60}, "alltitles": {"About": [[0, "about"], [99, "about"]], "Impact": [[0, "impact"], [99, "impact"]], "Usage": [[0, "usage"], [99, "usage"]], "Acknowledgements": [[0, "acknowledgements"], [99, "acknowledgements"]], "API": [[1, "module-graphnet"]], "constants": [[2, "module-graphnet.constants"], [4, "module-graphnet.data.constants"]], "data": [[3, "module-graphnet.data"]], "dataconverter": [[5, "module-graphnet.data.dataconverter"]], "dataloader": [[6, "module-graphnet.data.dataloader"]], "dataset": [[7, "module-graphnet.data.dataset"], [8, "module-graphnet.data.dataset.dataset"]], "parquet": [[9, "module-graphnet.data.dataset.parquet"], [31, "module-graphnet.data.parquet"]], "parquet_dataset": [[10, "module-graphnet.data.dataset.parquet.parquet_dataset"]], "sqlite": [[11, "module-graphnet.data.dataset.sqlite"], [34, "module-graphnet.data.sqlite"]], "sqlite_dataset": [[12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"]], "sqlite_dataset_perturbed": [[13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"]], "extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "utilities": [[27, "module-graphnet.data.extractors.utilities"], [37, "module-graphnet.data.utilities"], [83, "module-graphnet.utilities"]], "collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "types": [[30, "module-graphnet.data.extractors.utilities.types"]], "parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "pipeline": [[33, "module-graphnet.data.pipeline"]], "sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "random": [[39, "module-graphnet.data.utilities.random"]], "string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "deployment": [[41, "module-graphnet.deployment"]], "i3modules": [[42, "i3modules"]], "deployer": [[43, "deployer"]], "graphnet_module": [[44, "module-graphnet.deployment.i3modules.graphnet_module"]], "models": [[45, "module-graphnet.models"]], "coarsening": [[46, "module-graphnet.models.coarsening"]], "components": [[47, "module-graphnet.models.components"]], "layers": [[48, "module-graphnet.models.components.layers"]], "pool": [[49, "module-graphnet.models.components.pool"]], "detector": [[50, "module-graphnet.models.detector"], [51, "module-graphnet.models.detector.detector"]], "icecube": [[52, "module-graphnet.models.detector.icecube"]], "prometheus": [[53, "module-graphnet.models.detector.prometheus"]], "gnn": [[54, "module-graphnet.models.gnn"], [59, "module-graphnet.models.gnn.gnn"]], "convnet": [[55, "module-graphnet.models.gnn.convnet"]], "dynedge": [[56, "module-graphnet.models.gnn.dynedge"]], "dynedge_jinst": [[57, "module-graphnet.models.gnn.dynedge_jinst"]], "dynedge_kaggle_tito": [[58, "module-graphnet.models.gnn.dynedge_kaggle_tito"]], "graphs": [[60, "module-graphnet.models.graphs"], [64, "module-graphnet.models.graphs.graphs"]], "edges": [[61, "module-graphnet.models.graphs.edges"], [62, "module-graphnet.models.graphs.edges.edges"]], "graph_definition": [[63, "module-graphnet.models.graphs.graph_definition"]], "nodes": [[65, "module-graphnet.models.graphs.nodes"], [66, "module-graphnet.models.graphs.nodes.nodes"]], "model": [[67, "module-graphnet.models.model"]], "standard_model": [[68, "module-graphnet.models.standard_model"]], "task": [[69, "module-graphnet.models.task"], [72, "module-graphnet.models.task.task"]], "classification": [[70, "module-graphnet.models.task.classification"]], "reconstruction": [[71, "module-graphnet.models.task.reconstruction"]], "utils": [[73, "module-graphnet.models.utils"], [81, "module-graphnet.training.utils"]], "pisa": [[74, "module-graphnet.pisa"]], "fitting": [[75, "module-graphnet.pisa.fitting"]], "plotting": [[76, "module-graphnet.pisa.plotting"]], "training": [[77, "module-graphnet.training"]], "callbacks": [[78, "module-graphnet.training.callbacks"]], "labels": [[79, "module-graphnet.training.labels"]], "loss_functions": [[80, "module-graphnet.training.loss_functions"]], "weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "argparse": [[84, "module-graphnet.utilities.argparse"]], "config": [[85, "module-graphnet.utilities.config"]], "base_config": [[86, "module-graphnet.utilities.config.base_config"]], "configurable": [[87, "module-graphnet.utilities.config.configurable"]], "dataset_config": [[88, "module-graphnet.utilities.config.dataset_config"]], "model_config": [[89, "module-graphnet.utilities.config.model_config"]], "parsing": [[90, "module-graphnet.utilities.config.parsing"]], "training_config": [[91, "module-graphnet.utilities.config.training_config"]], "decorators": [[92, "module-graphnet.utilities.decorators"]], "filesys": [[93, "module-graphnet.utilities.filesys"]], "imports": [[94, "module-graphnet.utilities.imports"]], "logging": [[95, "module-graphnet.utilities.logging"]], "maths": [[96, "module-graphnet.utilities.maths"]], "src": [[97, "src"]], "Contribute": [[98, "contribute"]], "GitHub issues": [[98, "github-issues"]], "Pull requests": [[98, "pull-requests"]], "Conventions": [[98, "conventions"]], "Code quality": [[98, "code-quality"]], "Install": [[100, "install"]], "Installing with IceTray": [[100, "installing-with-icetray"]], "Installing stand-alone": [[100, "installing-stand-alone"]], "Running in Docker": [[100, "running-in-docker"]]}, "indexentries": {"graphnet": [[1, "module-graphnet"]], "module": [[1, "module-graphnet"], [2, "module-graphnet.constants"], [3, "module-graphnet.data"], [4, "module-graphnet.data.constants"], [5, "module-graphnet.data.dataconverter"], [6, "module-graphnet.data.dataloader"], [7, "module-graphnet.data.dataset"], [8, "module-graphnet.data.dataset.dataset"], [9, "module-graphnet.data.dataset.parquet"], [10, "module-graphnet.data.dataset.parquet.parquet_dataset"], [11, "module-graphnet.data.dataset.sqlite"], [12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"], [13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"], [14, "module-graphnet.data.extractors"], [15, "module-graphnet.data.extractors.i3extractor"], [16, "module-graphnet.data.extractors.i3featureextractor"], [17, "module-graphnet.data.extractors.i3genericextractor"], [18, "module-graphnet.data.extractors.i3hybridrecoextractor"], [19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"], [20, "module-graphnet.data.extractors.i3particleextractor"], [21, "module-graphnet.data.extractors.i3pisaextractor"], [22, "module-graphnet.data.extractors.i3quesoextractor"], [23, "module-graphnet.data.extractors.i3retroextractor"], [24, "module-graphnet.data.extractors.i3splinempeextractor"], [25, "module-graphnet.data.extractors.i3truthextractor"], [26, "module-graphnet.data.extractors.i3tumextractor"], [27, "module-graphnet.data.extractors.utilities"], [28, "module-graphnet.data.extractors.utilities.collections"], [29, "module-graphnet.data.extractors.utilities.frames"], [30, "module-graphnet.data.extractors.utilities.types"], [31, "module-graphnet.data.parquet"], [32, "module-graphnet.data.parquet.parquet_dataconverter"], [33, "module-graphnet.data.pipeline"], [34, "module-graphnet.data.sqlite"], [35, "module-graphnet.data.sqlite.sqlite_dataconverter"], [36, "module-graphnet.data.sqlite.sqlite_utilities"], [37, "module-graphnet.data.utilities"], [38, "module-graphnet.data.utilities.parquet_to_sqlite"], [39, "module-graphnet.data.utilities.random"], [40, "module-graphnet.data.utilities.string_selection_resolver"], [41, "module-graphnet.deployment"], [44, "module-graphnet.deployment.i3modules.graphnet_module"], [45, "module-graphnet.models"], [46, "module-graphnet.models.coarsening"], [47, "module-graphnet.models.components"], [48, "module-graphnet.models.components.layers"], [49, "module-graphnet.models.components.pool"], [50, "module-graphnet.models.detector"], [51, "module-graphnet.models.detector.detector"], [52, "module-graphnet.models.detector.icecube"], [53, "module-graphnet.models.detector.prometheus"], [54, "module-graphnet.models.gnn"], [55, "module-graphnet.models.gnn.convnet"], [56, "module-graphnet.models.gnn.dynedge"], [57, "module-graphnet.models.gnn.dynedge_jinst"], [58, "module-graphnet.models.gnn.dynedge_kaggle_tito"], [59, "module-graphnet.models.gnn.gnn"], [60, "module-graphnet.models.graphs"], [61, "module-graphnet.models.graphs.edges"], [62, "module-graphnet.models.graphs.edges.edges"], [63, "module-graphnet.models.graphs.graph_definition"], [64, "module-graphnet.models.graphs.graphs"], [65, "module-graphnet.models.graphs.nodes"], [66, "module-graphnet.models.graphs.nodes.nodes"], [67, "module-graphnet.models.model"], [68, "module-graphnet.models.standard_model"], [69, "module-graphnet.models.task"], [70, "module-graphnet.models.task.classification"], [71, "module-graphnet.models.task.reconstruction"], [72, "module-graphnet.models.task.task"], [73, "module-graphnet.models.utils"], [74, "module-graphnet.pisa"], [75, "module-graphnet.pisa.fitting"], [76, "module-graphnet.pisa.plotting"], [77, "module-graphnet.training"], [78, "module-graphnet.training.callbacks"], [79, "module-graphnet.training.labels"], [80, "module-graphnet.training.loss_functions"], [81, "module-graphnet.training.utils"], [82, "module-graphnet.training.weight_fitting"], [83, "module-graphnet.utilities"], [84, "module-graphnet.utilities.argparse"], [85, "module-graphnet.utilities.config"], [86, "module-graphnet.utilities.config.base_config"], [87, "module-graphnet.utilities.config.configurable"], [88, "module-graphnet.utilities.config.dataset_config"], [89, "module-graphnet.utilities.config.model_config"], [90, "module-graphnet.utilities.config.parsing"], [91, "module-graphnet.utilities.config.training_config"], [92, "module-graphnet.utilities.decorators"], [93, "module-graphnet.utilities.filesys"], [94, "module-graphnet.utilities.imports"], [95, "module-graphnet.utilities.logging"], [96, "module-graphnet.utilities.maths"]], "graphnet.constants": [[2, "module-graphnet.constants"]], "graphnet.data": [[3, "module-graphnet.data"]], "deepcore (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.DEEPCORE"]], "deepcore (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.DEEPCORE"]], "features (class in graphnet.data.constants)": [[4, "graphnet.data.constants.FEATURES"]], "icecube86 (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.ICECUBE86"]], "icecube86 (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.ICECUBE86"]], "kaggle (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.KAGGLE"]], "kaggle (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.KAGGLE"]], "prometheus (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.PROMETHEUS"]], "prometheus (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.PROMETHEUS"]], "truth (class in graphnet.data.constants)": [[4, "graphnet.data.constants.TRUTH"]], "upgrade (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.UPGRADE"]], "upgrade (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.UPGRADE"]], "graphnet.data.constants": [[4, "module-graphnet.data.constants"]], "dataconverter (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.DataConverter"]], "fileset (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.FileSet"]], "cache_output_files() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.cache_output_files"]], "execute() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.execute"]], "file_suffix (graphnet.data.dataconverter.dataconverter property)": [[5, "graphnet.data.dataconverter.DataConverter.file_suffix"]], "gcd_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.gcd_file"]], "get_map_function() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.get_map_function"]], "graphnet.data.dataconverter": [[5, "module-graphnet.data.dataconverter"]], "i3_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.i3_file"]], "init_global_index() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.init_global_index"]], "merge_files() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.merge_files"]], "save_data() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.save_data"]], "dataloader (class in graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.DataLoader"]], "collate_fn() (in module graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.collate_fn"]], "do_shuffle() (in module graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.do_shuffle"]], "from_dataset_config() (graphnet.data.dataloader.dataloader class method)": [[6, "graphnet.data.dataloader.DataLoader.from_dataset_config"]], "graphnet.data.dataloader": [[6, "module-graphnet.data.dataloader"]], "graphnet.data.dataset": [[7, "module-graphnet.data.dataset"]], "columnmissingexception": [[8, "graphnet.data.dataset.dataset.ColumnMissingException"]], "dataset (class in graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.Dataset"]], "ensembledataset (class in graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.EnsembleDataset"]], "add_label() (graphnet.data.dataset.dataset.dataset method)": [[8, "graphnet.data.dataset.dataset.Dataset.add_label"]], "concatenate() (graphnet.data.dataset.dataset.dataset class method)": [[8, "graphnet.data.dataset.dataset.Dataset.concatenate"]], "from_config() (graphnet.data.dataset.dataset.dataset class method)": [[8, "graphnet.data.dataset.dataset.Dataset.from_config"]], "graphnet.data.dataset.dataset": [[8, "module-graphnet.data.dataset.dataset"]], "load_module() (in module graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.load_module"]], "parse_graph_definition() (in module graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.parse_graph_definition"]], "path (graphnet.data.dataset.dataset.dataset property)": [[8, "graphnet.data.dataset.dataset.Dataset.path"]], "query_table() (graphnet.data.dataset.dataset.dataset method)": [[8, "graphnet.data.dataset.dataset.Dataset.query_table"]], "truth_table (graphnet.data.dataset.dataset.dataset property)": [[8, "graphnet.data.dataset.dataset.Dataset.truth_table"]], "graphnet.data.dataset.parquet": [[9, "module-graphnet.data.dataset.parquet"]], "parquetdataset (class in graphnet.data.dataset.parquet.parquet_dataset)": [[10, "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset"]], "graphnet.data.dataset.parquet.parquet_dataset": [[10, "module-graphnet.data.dataset.parquet.parquet_dataset"]], "query_table() (graphnet.data.dataset.parquet.parquet_dataset.parquetdataset method)": [[10, "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table"]], "graphnet.data.dataset.sqlite": [[11, "module-graphnet.data.dataset.sqlite"]], "sqlitedataset (class in graphnet.data.dataset.sqlite.sqlite_dataset)": [[12, "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset"]], "graphnet.data.dataset.sqlite.sqlite_dataset": [[12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"]], "query_table() (graphnet.data.dataset.sqlite.sqlite_dataset.sqlitedataset method)": [[12, "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table"]], "sqlitedatasetperturbed (class in graphnet.data.dataset.sqlite.sqlite_dataset_perturbed)": [[13, "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed"]], "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed": [[13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"]], "graphnet.data.extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor"]], "i3extractorcollection (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "name (graphnet.data.extractors.i3extractor.i3extractor property)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.name"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractor method)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.set_files"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractorcollection method)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection.set_files"]], "i3featureextractor (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"]], "i3featureextractoricecube86 (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCube86"]], "i3featureextractoricecubedeepcore (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeDeepCore"]], "i3featureextractoricecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeUpgrade"]], "i3pulsenoisetruthflagicecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor (class in graphnet.data.extractors.i3genericextractor)": [[17, "graphnet.data.extractors.i3genericextractor.I3GenericExtractor"]], "graphnet.data.extractors.i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3galacticplanehybridrecoextractor (class in graphnet.data.extractors.i3hybridrecoextractor)": [[18, "graphnet.data.extractors.i3hybridrecoextractor.I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelextractor (class in graphnet.data.extractors.i3ntmuonlabelsextractor)": [[19, "graphnet.data.extractors.i3ntmuonlabelsextractor.I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor (class in graphnet.data.extractors.i3particleextractor)": [[20, "graphnet.data.extractors.i3particleextractor.I3ParticleExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor (class in graphnet.data.extractors.i3pisaextractor)": [[21, "graphnet.data.extractors.i3pisaextractor.I3PISAExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor (class in graphnet.data.extractors.i3quesoextractor)": [[22, "graphnet.data.extractors.i3quesoextractor.I3QUESOExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor (class in graphnet.data.extractors.i3retroextractor)": [[23, "graphnet.data.extractors.i3retroextractor.I3RetroExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeicextractor (class in graphnet.data.extractors.i3splinempeextractor)": [[24, "graphnet.data.extractors.i3splinempeextractor.I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor (class in graphnet.data.extractors.i3truthextractor)": [[25, "graphnet.data.extractors.i3truthextractor.I3TruthExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor (class in graphnet.data.extractors.i3tumextractor)": [[26, "graphnet.data.extractors.i3tumextractor.I3TUMExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "graphnet.data.extractors.utilities": [[27, "module-graphnet.data.extractors.utilities"]], "flatten_nested_dictionary() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.flatten_nested_dictionary"]], "graphnet.data.extractors.utilities.collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "serialise() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.serialise"]], "transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.transpose_list_of_dicts"]], "frame_is_montecarlo() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_montecarlo"]], "frame_is_noise() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_noise"]], "get_om_keys_and_pulseseries() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "break_cyclic_recursion() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.break_cyclic_recursion"]], "cast_object_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_object_to_pure_python"]], "cast_pulse_series_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_pulse_series_to_pure_python"]], "get_member_variables() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.get_member_variables"]], "graphnet.data.extractors.utilities.types": [[30, "module-graphnet.data.extractors.utilities.types"]], "is_boost_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_class"]], "is_boost_enum() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_enum"]], "is_icecube_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_icecube_class"]], "is_method() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_method"]], "is_type() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_type"]], "graphnet.data.parquet": [[31, "module-graphnet.data.parquet"]], "parquetdataconverter (class in graphnet.data.parquet.parquet_dataconverter)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter"]], "file_suffix (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter attribute)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.file_suffix"]], "graphnet.data.parquet.parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "merge_files() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.merge_files"]], "save_data() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.save_data"]], "insqlitepipeline (class in graphnet.data.pipeline)": [[33, "graphnet.data.pipeline.InSQLitePipeline"]], "graphnet.data.pipeline": [[33, "module-graphnet.data.pipeline"]], "graphnet.data.sqlite": [[34, "module-graphnet.data.sqlite"]], "sqlitedataconverter (class in graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter"]], "any_pulsemap_is_non_empty() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.any_pulsemap_is_non_empty"]], "construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe"]], "file_suffix (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter attribute)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.file_suffix"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "is_mc_tree() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_mc_tree"]], "is_pulse_map() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_pulse_map"]], "merge_files() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.merge_files"]], "save_data() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.save_data"]], "attach_index() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.attach_index"]], "create_table() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table"]], "create_table_and_save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table_and_save_to_sql"]], "database_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_exists"]], "database_table_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_table_exists"]], "graphnet.data.sqlite.sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "run_sql_code() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.run_sql_code"]], "save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.save_to_sql"]], "graphnet.data.utilities": [[37, "module-graphnet.data.utilities"]], "parquettosqliteconverter (class in graphnet.data.utilities.parquet_to_sqlite)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "run() (graphnet.data.utilities.parquet_to_sqlite.parquettosqliteconverter method)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter.run"]], "graphnet.data.utilities.random": [[39, "module-graphnet.data.utilities.random"]], "pairwise_shuffle() (in module graphnet.data.utilities.random)": [[39, "graphnet.data.utilities.random.pairwise_shuffle"]], "stringselectionresolver (class in graphnet.data.utilities.string_selection_resolver)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "resolve() (graphnet.data.utilities.string_selection_resolver.stringselectionresolver method)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver.resolve"]], "graphnet.deployment": [[41, "module-graphnet.deployment"]], "graphneti3module (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module"]], "i3inferencemodule (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.I3InferenceModule"]], "i3pulsecleanermodule (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule"]], "graphnet.deployment.i3modules.graphnet_module": [[44, "module-graphnet.deployment.i3modules.graphnet_module"]], "graphnet.models": [[45, "module-graphnet.models"]], "attributecoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.AttributeCoarsening"]], "coarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.Coarsening"]], "customdomcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.CustomDOMCoarsening"]], "domandtimewindowcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.DOMAndTimeWindowCoarsening"]], "domcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.DOMCoarsening"]], "forward() (graphnet.models.coarsening.coarsening method)": [[46, "graphnet.models.coarsening.Coarsening.forward"]], "graphnet.models.coarsening": [[46, "module-graphnet.models.coarsening"]], "reduce_options (graphnet.models.coarsening.coarsening attribute)": [[46, "graphnet.models.coarsening.Coarsening.reduce_options"]], "unbatch_edge_index() (in module graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.unbatch_edge_index"]], "graphnet.models.components": [[47, "module-graphnet.models.components"]], "dynedgeconv (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.DynEdgeConv"]], "dyntrans (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.DynTrans"]], "edgeconvtito (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.EdgeConvTito"]], "forward() (graphnet.models.components.layers.dynedgeconv method)": [[48, "graphnet.models.components.layers.DynEdgeConv.forward"]], "forward() (graphnet.models.components.layers.dyntrans method)": [[48, "graphnet.models.components.layers.DynTrans.forward"]], "forward() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.forward"]], "graphnet.models.components.layers": [[48, "module-graphnet.models.components.layers"]], "message() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.message"]], "reset_parameters() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.reset_parameters"]], "graphnet.models.components.pool": [[49, "module-graphnet.models.components.pool"]], "group_by() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_by"]], "group_pulses_to_dom() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_pulses_to_dom"]], "group_pulses_to_pmt() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_pulses_to_pmt"]], "min_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.min_pool"]], "min_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.min_pool_x"]], "std_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.std_pool"]], "std_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.std_pool_x"]], "sum_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool"]], "sum_pool_and_distribute() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool_and_distribute"]], "sum_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool_x"]], "graphnet.models.detector": [[50, "module-graphnet.models.detector"]], "detector (class in graphnet.models.detector.detector)": [[51, "graphnet.models.detector.detector.Detector"]], "feature_map() (graphnet.models.detector.detector.detector method)": [[51, "graphnet.models.detector.detector.Detector.feature_map"]], "forward() (graphnet.models.detector.detector.detector method)": [[51, "graphnet.models.detector.detector.Detector.forward"]], "graphnet.models.detector.detector": [[51, "module-graphnet.models.detector.detector"]], "icecube86 (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCube86"]], "icecubedeepcore (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeDeepCore"]], "icecubekaggle (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeKaggle"]], "icecubeupgrade (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeUpgrade"]], "feature_map() (graphnet.models.detector.icecube.icecube86 method)": [[52, "graphnet.models.detector.icecube.IceCube86.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubedeepcore method)": [[52, "graphnet.models.detector.icecube.IceCubeDeepCore.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubekaggle method)": [[52, "graphnet.models.detector.icecube.IceCubeKaggle.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubeupgrade method)": [[52, "graphnet.models.detector.icecube.IceCubeUpgrade.feature_map"]], "graphnet.models.detector.icecube": [[52, "module-graphnet.models.detector.icecube"]], "prometheus (class in graphnet.models.detector.prometheus)": [[53, "graphnet.models.detector.prometheus.Prometheus"]], "feature_map() (graphnet.models.detector.prometheus.prometheus method)": [[53, "graphnet.models.detector.prometheus.Prometheus.feature_map"]], "graphnet.models.detector.prometheus": [[53, "module-graphnet.models.detector.prometheus"]], "graphnet.models.gnn": [[54, "module-graphnet.models.gnn"]], "convnet (class in graphnet.models.gnn.convnet)": [[55, "graphnet.models.gnn.convnet.ConvNet"]], "forward() (graphnet.models.gnn.convnet.convnet method)": [[55, "graphnet.models.gnn.convnet.ConvNet.forward"]], "graphnet.models.gnn.convnet": [[55, "module-graphnet.models.gnn.convnet"]], "dynedge (class in graphnet.models.gnn.dynedge)": [[56, "graphnet.models.gnn.dynedge.DynEdge"]], "forward() (graphnet.models.gnn.dynedge.dynedge method)": [[56, "graphnet.models.gnn.dynedge.DynEdge.forward"]], "graphnet.models.gnn.dynedge": [[56, "module-graphnet.models.gnn.dynedge"]], "dynedgejinst (class in graphnet.models.gnn.dynedge_jinst)": [[57, "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST"]], "forward() (graphnet.models.gnn.dynedge_jinst.dynedgejinst method)": [[57, "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward"]], "graphnet.models.gnn.dynedge_jinst": [[57, "module-graphnet.models.gnn.dynedge_jinst"]], "dynedgetito (class in graphnet.models.gnn.dynedge_kaggle_tito)": [[58, "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO"]], "forward() (graphnet.models.gnn.dynedge_kaggle_tito.dynedgetito method)": [[58, "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward"]], "graphnet.models.gnn.dynedge_kaggle_tito": [[58, "module-graphnet.models.gnn.dynedge_kaggle_tito"]], "gnn (class in graphnet.models.gnn.gnn)": [[59, "graphnet.models.gnn.gnn.GNN"]], "forward() (graphnet.models.gnn.gnn.gnn method)": [[59, "graphnet.models.gnn.gnn.GNN.forward"]], "graphnet.models.gnn.gnn": [[59, "module-graphnet.models.gnn.gnn"]], "nb_inputs (graphnet.models.gnn.gnn.gnn property)": [[59, "graphnet.models.gnn.gnn.GNN.nb_inputs"]], "nb_outputs (graphnet.models.gnn.gnn.gnn property)": [[59, "graphnet.models.gnn.gnn.GNN.nb_outputs"]], "graphnet.models.graphs": [[60, "module-graphnet.models.graphs"]], "graphnet.models.graphs.edges": [[61, "module-graphnet.models.graphs.edges"]], "edgedefinition (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.EdgeDefinition"]], "euclideanedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.EuclideanEdges"]], "knnedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.KNNEdges"]], "radialedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.RadialEdges"]], "forward() (graphnet.models.graphs.edges.edges.edgedefinition method)": [[62, "graphnet.models.graphs.edges.edges.EdgeDefinition.forward"]], "graphnet.models.graphs.edges.edges": [[62, "module-graphnet.models.graphs.edges.edges"]], "graphdefinition (class in graphnet.models.graphs.graph_definition)": [[63, "graphnet.models.graphs.graph_definition.GraphDefinition"]], "forward() (graphnet.models.graphs.graph_definition.graphdefinition method)": [[63, "graphnet.models.graphs.graph_definition.GraphDefinition.forward"]], "graphnet.models.graphs.graph_definition": [[63, "module-graphnet.models.graphs.graph_definition"]], "knngraph (class in graphnet.models.graphs.graphs)": [[64, "graphnet.models.graphs.graphs.KNNGraph"]], "graphnet.models.graphs.graphs": [[64, "module-graphnet.models.graphs.graphs"]], "graphnet.models.graphs.nodes": [[65, "module-graphnet.models.graphs.nodes"]], "nodedefinition (class in graphnet.models.graphs.nodes.nodes)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition"]], "nodesaspulses (class in graphnet.models.graphs.nodes.nodes)": [[66, "graphnet.models.graphs.nodes.nodes.NodesAsPulses"]], "forward() (graphnet.models.graphs.nodes.nodes.nodedefinition method)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.forward"]], "graphnet.models.graphs.nodes.nodes": [[66, "module-graphnet.models.graphs.nodes.nodes"]], "nb_outputs (graphnet.models.graphs.nodes.nodes.nodedefinition property)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs"]], "set_number_of_inputs() (graphnet.models.graphs.nodes.nodes.nodedefinition method)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs"]], "model (class in graphnet.models.model)": [[67, "graphnet.models.model.Model"]], "fit() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.fit"]], "forward() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.forward"]], "from_config() (graphnet.models.model.model class method)": [[67, "graphnet.models.model.Model.from_config"]], "graphnet.models.model": [[67, "module-graphnet.models.model"]], "load() (graphnet.models.model.model class method)": [[67, "graphnet.models.model.Model.load"]], "load_state_dict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.load_state_dict"]], "predict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.predict"]], "predict_as_dataframe() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.predict_as_dataframe"]], "save() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.save"]], "save_state_dict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.save_state_dict"]], "standardmodel (class in graphnet.models.standard_model)": [[68, "graphnet.models.standard_model.StandardModel"]], "compute_loss() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.compute_loss"]], "configure_optimizers() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.configure_optimizers"]], "forward() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.forward"]], "graphnet.models.standard_model": [[68, "module-graphnet.models.standard_model"]], "inference() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.inference"]], "predict() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.predict"]], "predict_as_dataframe() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.predict_as_dataframe"]], "prediction_labels (graphnet.models.standard_model.standardmodel property)": [[68, "graphnet.models.standard_model.StandardModel.prediction_labels"]], "shared_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.shared_step"]], "target_labels (graphnet.models.standard_model.standardmodel property)": [[68, "graphnet.models.standard_model.StandardModel.target_labels"]], "train() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.train"]], "training_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.training_step"]], "validation_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.validation_step"]], "graphnet.models.task": [[69, "module-graphnet.models.task"]], "binaryclassificationtask (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.BinaryClassificationTask"]], "binaryclassificationtasklogits (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits"]], "multiclassclassificationtask (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.MulticlassClassificationTask"]], "default_prediction_labels (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels"]], "default_target_labels (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.default_target_labels"]], "default_target_labels (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels"]], "graphnet.models.task.classification": [[70, "module-graphnet.models.task.classification"]], "nb_inputs (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.nb_inputs"]], "nb_inputs (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs"]], "azimuthreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction"]], "azimuthreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa"]], "directionreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa"]], "energyreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction"]], "energyreconstructionwithpower (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower"]], "energyreconstructionwithuncertainty (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty"]], "inelasticityreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction"]], "positionreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction"]], "timereconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction"]], "vertexreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction"]], "zenithreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction"]], "zenithreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa"]], "default_prediction_labels (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels"]], "default_target_labels (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels"]], "graphnet.models.task.reconstruction": [[71, "module-graphnet.models.task.reconstruction"]], "nb_inputs (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs"]], "identitytask (class in graphnet.models.task.task)": [[72, "graphnet.models.task.task.IdentityTask"]], "task (class in graphnet.models.task.task)": [[72, "graphnet.models.task.task.Task"]], "compute_loss() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.compute_loss"]], "default_prediction_labels (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.default_prediction_labels"]], "default_target_labels (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.default_target_labels"]], "default_target_labels (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.default_target_labels"]], "forward() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.forward"]], "graphnet.models.task.task": [[72, "module-graphnet.models.task.task"]], "inference() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.inference"]], "nb_inputs (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.nb_inputs"]], "nb_inputs (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.nb_inputs"]], "train_eval() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.train_eval"]], "calculate_distance_matrix() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.calculate_distance_matrix"]], "calculate_xyzt_homophily() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.calculate_xyzt_homophily"]], "graphnet.models.utils": [[73, "module-graphnet.models.utils"]], "knn_graph_batch() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.knn_graph_batch"]], "graphnet.pisa": [[74, "module-graphnet.pisa"]], "contourfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.ContourFitter"]], "weightfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.WeightFitter"]], "config_updater() (in module graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.config_updater"]], "fit_1d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_1d_contour"]], "fit_2d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_2d_contour"]], "fit_weights() (graphnet.pisa.fitting.weightfitter method)": [[75, "graphnet.pisa.fitting.WeightFitter.fit_weights"]], "graphnet.pisa.fitting": [[75, "module-graphnet.pisa.fitting"]], "graphnet.pisa.plotting": [[76, "module-graphnet.pisa.plotting"]], "plot_1d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_1D_contour"]], "plot_2d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_2D_contour"]], "read_entry() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.read_entry"]], "graphnet.training": [[77, "module-graphnet.training"]], "piecewiselinearlr (class in graphnet.training.callbacks)": [[78, "graphnet.training.callbacks.PiecewiseLinearLR"]], "progressbar (class in graphnet.training.callbacks)": [[78, "graphnet.training.callbacks.ProgressBar"]], "get_lr() (graphnet.training.callbacks.piecewiselinearlr method)": [[78, "graphnet.training.callbacks.PiecewiseLinearLR.get_lr"]], "get_metrics() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.get_metrics"]], "graphnet.training.callbacks": [[78, "module-graphnet.training.callbacks"]], "init_predict_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_predict_tqdm"]], "init_test_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_test_tqdm"]], "init_train_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_train_tqdm"]], "init_validation_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_validation_tqdm"]], "on_train_epoch_end() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.on_train_epoch_end"]], "on_train_epoch_start() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.on_train_epoch_start"]], "direction (class in graphnet.training.labels)": [[79, "graphnet.training.labels.Direction"]], "label (class in graphnet.training.labels)": [[79, "graphnet.training.labels.Label"]], "graphnet.training.labels": [[79, "module-graphnet.training.labels"]], "key (graphnet.training.labels.label property)": [[79, "graphnet.training.labels.Label.key"]], "binarycrossentropyloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.BinaryCrossEntropyLoss"]], "crossentropyloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.CrossEntropyLoss"]], "euclideandistanceloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.EuclideanDistanceLoss"]], "logcmk (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LogCMK"]], "logcoshloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LogCoshLoss"]], "lossfunction (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LossFunction"]], "mseloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.MSELoss"]], "rmseloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.RMSELoss"]], "vonmisesfisher2dloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisher2DLoss"]], "vonmisesfisher3dloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisher3DLoss"]], "vonmisesfisherloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss"]], "backward() (graphnet.training.loss_functions.logcmk static method)": [[80, "graphnet.training.loss_functions.LogCMK.backward"]], "forward() (graphnet.training.loss_functions.logcmk static method)": [[80, "graphnet.training.loss_functions.LogCMK.forward"]], "forward() (graphnet.training.loss_functions.lossfunction method)": [[80, "graphnet.training.loss_functions.LossFunction.forward"]], "graphnet.training.loss_functions": [[80, "module-graphnet.training.loss_functions"]], "log_cmk() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk"]], "log_cmk_approx() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx"]], "log_cmk_exact() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact"]], "collate_fn() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.collate_fn"]], "get_predictions() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.get_predictions"]], "graphnet.training.utils": [[81, "module-graphnet.training.utils"]], "make_dataloader() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.make_dataloader"]], "make_train_validation_dataloader() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.make_train_validation_dataloader"]], "save_results() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.save_results"]], "bjoernlow (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.BjoernLow"]], "uniform (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.Uniform"]], "weightfitter (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.WeightFitter"]], "fit() (graphnet.training.weight_fitting.weightfitter method)": [[82, "graphnet.training.weight_fitting.WeightFitter.fit"]], "graphnet.training.weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "graphnet.utilities": [[83, "module-graphnet.utilities"]], "argumentparser (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.ArgumentParser"]], "options (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.Options"]], "contains() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.contains"]], "graphnet.utilities.argparse": [[84, "module-graphnet.utilities.argparse"]], "pop_default() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.pop_default"]], "standard_arguments (graphnet.utilities.argparse.argumentparser attribute)": [[84, "graphnet.utilities.argparse.ArgumentParser.standard_arguments"]], "with_standard_arguments() (graphnet.utilities.argparse.argumentparser method)": [[84, "graphnet.utilities.argparse.ArgumentParser.with_standard_arguments"]], "graphnet.utilities.config": [[85, "module-graphnet.utilities.config"]], "baseconfig (class in graphnet.utilities.config.base_config)": [[86, "graphnet.utilities.config.base_config.BaseConfig"]], "as_dict() (graphnet.utilities.config.base_config.baseconfig method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.as_dict"]], "dump() (graphnet.utilities.config.base_config.baseconfig method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.dump"]], "get_all_argument_values() (in module graphnet.utilities.config.base_config)": [[86, "graphnet.utilities.config.base_config.get_all_argument_values"]], "graphnet.utilities.config.base_config": [[86, "module-graphnet.utilities.config.base_config"]], "load() (graphnet.utilities.config.base_config.baseconfig class method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.load"]], "model_config (graphnet.utilities.config.base_config.baseconfig attribute)": [[86, "graphnet.utilities.config.base_config.BaseConfig.model_config"]], "model_fields (graphnet.utilities.config.base_config.baseconfig attribute)": [[86, "graphnet.utilities.config.base_config.BaseConfig.model_fields"]], "configurable (class in graphnet.utilities.config.configurable)": [[87, "graphnet.utilities.config.configurable.Configurable"]], "config (graphnet.utilities.config.configurable.configurable property)": [[87, "graphnet.utilities.config.configurable.Configurable.config"]], "from_config() (graphnet.utilities.config.configurable.configurable class method)": [[87, "graphnet.utilities.config.configurable.Configurable.from_config"]], "graphnet.utilities.config.configurable": [[87, "module-graphnet.utilities.config.configurable"]], "save_config() (graphnet.utilities.config.configurable.configurable method)": [[87, "graphnet.utilities.config.configurable.Configurable.save_config"]], "datasetconfig (class in graphnet.utilities.config.dataset_config)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig"]], "as_dict() (graphnet.utilities.config.dataset_config.datasetconfig method)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.as_dict"]], "features (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.features"]], "graph_definition (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition"]], "graphnet.utilities.config.dataset_config": [[88, "module-graphnet.utilities.config.dataset_config"]], "index_column (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.index_column"]], "loss_weight_column (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column"]], "loss_weight_default_value (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value"]], "loss_weight_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table"]], "model_config (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.model_config"]], "model_fields (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.model_fields"]], "node_truth (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.node_truth"]], "node_truth_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table"]], "path (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.path"]], "pulsemaps (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps"]], "save_dataset_config() (in module graphnet.utilities.config.dataset_config)": [[88, "graphnet.utilities.config.dataset_config.save_dataset_config"]], "seed (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.seed"]], "selection (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.selection"]], "string_selection (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.string_selection"]], "truth (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.truth"]], "truth_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.truth_table"]], "modelconfig (class in graphnet.utilities.config.model_config)": [[89, "graphnet.utilities.config.model_config.ModelConfig"]], "arguments (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.arguments"]], "as_dict() (graphnet.utilities.config.model_config.modelconfig method)": [[89, "graphnet.utilities.config.model_config.ModelConfig.as_dict"]], "class_name (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.class_name"]], "graphnet.utilities.config.model_config": [[89, "module-graphnet.utilities.config.model_config"]], "model_config (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.model_config"]], "model_fields (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.model_fields"]], "save_model_config() (in module graphnet.utilities.config.model_config)": [[89, "graphnet.utilities.config.model_config.save_model_config"]], "get_all_grapnet_classes() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.get_all_grapnet_classes"]], "get_graphnet_classes() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.get_graphnet_classes"]], "graphnet.utilities.config.parsing": [[90, "module-graphnet.utilities.config.parsing"]], "is_graphnet_class() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.is_graphnet_class"]], "is_graphnet_module() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.is_graphnet_module"]], "list_all_submodules() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.list_all_submodules"]], "traverse_and_apply() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.traverse_and_apply"]], "trainingconfig (class in graphnet.utilities.config.training_config)": [[91, "graphnet.utilities.config.training_config.TrainingConfig"]], "dataloader (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.dataloader"]], "early_stopping_patience (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience"]], "fit (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.fit"]], "graphnet.utilities.config.training_config": [[91, "module-graphnet.utilities.config.training_config"]], "model_config (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.model_config"]], "model_fields (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.model_fields"]], "target (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.target"]], "graphnet.utilities.decorators": [[92, "module-graphnet.utilities.decorators"]], "find_i3_files() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.find_i3_files"]], "graphnet.utilities.filesys": [[93, "module-graphnet.utilities.filesys"]], "has_extension() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.has_extension"]], "is_gcd_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_gcd_file"]], "is_i3_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_i3_file"]], "graphnet.utilities.imports": [[94, "module-graphnet.utilities.imports"]], "has_icecube_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_icecube_package"]], "has_pisa_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_pisa_package"]], "has_torch_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_torch_package"]], "requires_icecube() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.requires_icecube"]], "logger (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.Logger"]], "repeatfilter (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.RepeatFilter"]], "critical() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.critical"]], "debug() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.debug"]], "error() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.error"]], "file_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.file_handlers"]], "filter() (graphnet.utilities.logging.repeatfilter method)": [[95, "graphnet.utilities.logging.RepeatFilter.filter"]], "graphnet.utilities.logging": [[95, "module-graphnet.utilities.logging"]], "handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.handlers"]], "info() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.info"]], "nb_repeats_allowed (graphnet.utilities.logging.repeatfilter attribute)": [[95, "graphnet.utilities.logging.RepeatFilter.nb_repeats_allowed"]], "setlevel() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.setLevel"]], "stream_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.stream_handlers"]], "warning() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning"]], "warning_once() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning_once"]], "eps_like() (in module graphnet.utilities.maths)": [[96, "graphnet.utilities.maths.eps_like"]], "graphnet.utilities.maths": [[96, "module-graphnet.utilities.maths"]]}}) \ No newline at end of file diff --git a/sitemap.xml b/sitemap.xml index 8ca584144..1332adef5 100644 --- a/sitemap.xml +++ b/sitemap.xml @@ -1 +1 @@ -<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"><url><loc>https://graphnet-team.github.io/graphnetabout.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataloader.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.parquet_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3extractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3featureextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3genericextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3hybridrecoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3ntmuonlabelsextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3particleextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3pisaextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3quesoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3retroextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3splinempeextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3truthextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3tumextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.collections.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.frames.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.types.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.parquet_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.pipeline.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.parquet_to_sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.random.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.string_selection_resolver.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.deployer.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.graphnet_module.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.coarsening.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.layers.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.pool.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.icecube.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.prometheus.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.convnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_jinst.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_kaggle_tito.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graph_definition.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.standard_model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.classification.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.reconstruction.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.plotting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.callbacks.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.labels.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.loss_functions.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.weight_fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.argparse.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.base_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.configurable.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.dataset_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.model_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.parsing.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.training_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.decorators.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.filesys.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.imports.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.logging.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.maths.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/modules.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetcontribute.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetinstall.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetgenindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetpy-modindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3extractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3featureextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3genericextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3hybridrecoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3particleextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3pisaextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3quesoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3retroextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3splinempeextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3truthextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3tumextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/collections.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/frames.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/types.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/parquet/parquet_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/parquet_to_sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/random.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/string_selection_resolver.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/pisa/fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/pisa/plotting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/weight_fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/argparse.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/filesys.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/imports.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/logging.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/index.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetsearch.html</loc></url></urlset> \ No newline at end of file +<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"><url><loc>https://graphnet-team.github.io/graphnetabout.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataloader.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.parquet_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3extractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3featureextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3genericextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3hybridrecoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3ntmuonlabelsextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3particleextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3pisaextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3quesoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3retroextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3splinempeextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3truthextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3tumextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.collections.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.frames.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.types.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.parquet_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.pipeline.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.parquet_to_sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.random.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.string_selection_resolver.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.deployer.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.graphnet_module.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.coarsening.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.layers.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.components.pool.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.icecube.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.detector.prometheus.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.convnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_jinst.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_kaggle_tito.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graph_definition.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.standard_model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.classification.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.reconstruction.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.task.task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.models.utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.pisa.plotting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.callbacks.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.labels.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.loss_functions.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.training.weight_fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.argparse.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.base_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.configurable.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.dataset_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.model_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.parsing.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.training_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.decorators.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.filesys.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.imports.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.logging.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/graphnet.utilities.maths.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetapi/modules.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetcontribute.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetinstall.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetgenindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetpy-modindex.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/constants.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataloader.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/parquet/parquet_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3extractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3featureextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3genericextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3hybridrecoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3particleextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3pisaextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3quesoextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3retroextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3splinempeextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3truthextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3tumextractor.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/collections.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/frames.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/types.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/parquet/parquet_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/pipeline.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_dataconverter.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_utilities.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/parquet_to_sqlite.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/random.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/string_selection_resolver.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/deployment/i3modules/graphnet_module.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/coarsening.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/components/layers.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/components/pool.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/detector.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/icecube.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/prometheus.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/convnet.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge_jinst.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge_kaggle_tito.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/gnn.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/edges/edges.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/graph_definition.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/graphs.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/nodes/nodes.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/standard_model.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/task/classification.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/task/reconstruction.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/task/task.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/models/utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/pisa/fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/pisa/plotting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/callbacks.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/labels.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/loss_functions.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/utils.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/training/weight_fitting.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/argparse.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/base_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/configurable.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/dataset_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/model_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/parsing.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/training_config.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/filesys.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/imports.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/logging.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/graphnet/utilities/maths.html</loc></url><url><loc>https://graphnet-team.github.io/graphnet_modules/index.html</loc></url><url><loc>https://graphnet-team.github.io/graphnetsearch.html</loc></url></urlset> \ No newline at end of file