From a3b722ede15810d0cf373bfb37826cc37a6d385e Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Wed, 13 Sep 2023 10:41:38 +0200 Subject: [PATCH 01/26] Create MetaClasses to save Model/Dataset configs --- src/graphnet/utilities/config/__init__.py | 4 +- .../utilities/config/dataset_config.py | 50 ++++++++----------- src/graphnet/utilities/config/model_config.py | 47 ++++++++--------- 3 files changed, 45 insertions(+), 56 deletions(-) diff --git a/src/graphnet/utilities/config/__init__.py b/src/graphnet/utilities/config/__init__.py index 5e37c6a00..426eae788 100644 --- a/src/graphnet/utilities/config/__init__.py +++ b/src/graphnet/utilities/config/__init__.py @@ -1,6 +1,6 @@ """Modules for configuration files for use across `graphnet`.""" from .configurable import Configurable -from .dataset_config import DatasetConfig, save_dataset_config -from .model_config import ModelConfig, save_model_config +from .dataset_config import DatasetConfig, DatasetConfigSaverMeta +from .model_config import ModelConfig, ModelConfigSaverMeta from .training_config import TrainingConfig diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index 34d92fc3c..b48459315 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -178,39 +178,33 @@ def _parse_torch(self, obj: Any) -> Any: return obj -def save_dataset_config(init_fn: Callable) -> Callable: - """Save the arguments to `__init__` functions as member `DatasetConfig`.""" - - def _replace_model_instance_with_config( - obj: Union["Model", Any] - ) -> Union[ModelConfig, Any]: - """Replace `Model` instances in `obj` with their `ModelConfig`.""" - from graphnet.models import Model - import torch +class DatasetConfigSaverMeta(type): + """Metaclass for `DatasetConfig` that saves the config after `__init__`.""" - if isinstance(obj, Model): - return obj.config + def __call__(cls: Any, *args: Any, **kwargs: Any) -> object: + """Catch object construction and save config after `__init__`.""" - if isinstance(obj, torch.dtype): - return obj.__str__() + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model - else: - return obj + if isinstance(obj, Model): + return obj.config + else: + return obj - @wraps(init_fn) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - """Set `DatasetConfig` after calling `init_fn`.""" - # Call wrapped method - ret = init_fn(self, *args, **kwargs) + # Create object + created_obj = super().__call__(*args, **kwargs) # Get all argument values, including defaults - cfg = get_all_argument_values(init_fn, *args, **kwargs) - - # Handle nested `Model`s, etc. 
+ cfg = get_all_argument_values(created_obj.__init__, *args, **kwargs) cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) - # Add `DatasetConfig` as member variables - self._config = DatasetConfig(**cfg) - return ret - - return wrapper + # Store config in + created_obj._config = DatasetConfig( + class_name=str(created_obj.__class__.__name__), + arguments=dict(**cfg), + ) + return created_obj diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index 9c4d21d26..e811a1ab0 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -248,38 +248,33 @@ def as_dict(self) -> Dict[str, Dict[str, Any]]: return {self.__class__.__name__: config_dict} -def save_model_config(init_fn: Callable) -> Callable: - """Save the arguments to `__init__` functions as a member `ModelConfig`.""" +class ModelConfigSaverMeta(type): + """Metaclass for saving `ModelConfig` to `Model` instances.""" - def _replace_model_instance_with_config( - obj: Union["Model", Any] - ) -> Union[ModelConfig, Any]: - """Replace `Model` instances in `obj` with their `ModelConfig`.""" - from graphnet.models import Model + def __call__(cls: Any, *args: Any, **kwargs: Any) -> object: + """Catch object construction and save config after `__init__`.""" - if isinstance(obj, Model): - return obj.config - else: - return obj + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model + + if isinstance(obj, Model): + return obj.config + else: + return obj - @wraps(init_fn) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - """Set `ModelConfig` after calling `init_fn`.""" - # Call wrapped method - ret = init_fn(self, *args, **kwargs) + # Create object + created_obj = super().__call__(*args, **kwargs) # Get all argument values, including defaults - cfg = get_all_argument_values(init_fn, *args, **kwargs) - - # Handle nested `Model`s, etc. 
+ cfg = get_all_argument_values(created_obj.__init__, *args, **kwargs) cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) - # Add `ModelConfig` as member variables - self._config = ModelConfig( - class_name=str(self.__class__.__name__), + # Store config in + created_obj._config = ModelConfig( + class_name=str(created_obj.__class__.__name__), arguments=dict(**cfg), ) - - return ret - - return wrapper + return created_obj From 682213caba71602aac2c7b3a1a37924790d5c0ac Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Wed, 13 Sep 2023 11:21:48 +0200 Subject: [PATCH 02/26] Remove usage of save_model_config and save_dataset_config --- src/graphnet/data/dataset/dataset.py | 11 ++++++++--- src/graphnet/models/coarsening.py | 3 --- src/graphnet/models/detector/detector.py | 2 -- src/graphnet/models/gnn/convnet.py | 2 -- src/graphnet/models/gnn/dynedge.py | 2 -- src/graphnet/models/gnn/dynedge_jinst.py | 2 -- src/graphnet/models/gnn/dynedge_kaggle_tito.py | 2 -- src/graphnet/models/gnn/gnn.py | 2 -- src/graphnet/models/graphs/edges/edges.py | 4 ---- src/graphnet/models/graphs/graph_definition.py | 3 --- src/graphnet/models/graphs/graphs.py | 2 -- src/graphnet/models/graphs/nodes/nodes.py | 2 -- src/graphnet/models/model.py | 10 ++++++++-- src/graphnet/models/standard_model.py | 2 -- src/graphnet/models/task/task.py | 3 --- src/graphnet/training/loss_functions.py | 3 --- src/graphnet/utilities/config/__init__.py | 14 ++++++++++++-- src/graphnet/utilities/config/dataset_config.py | 14 +++++++++++++- src/graphnet/utilities/config/model_config.py | 13 +++++++++++++ 19 files changed, 54 insertions(+), 42 deletions(-) diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py index c1f785bc9..c1e4aad16 100644 --- a/src/graphnet/data/dataset/dataset.py +++ b/src/graphnet/data/dataset/dataset.py @@ -27,7 +27,7 @@ from graphnet.utilities.config import ( Configurable, DatasetConfig, - save_dataset_config, + DatasetConfigSaverABCMeta, ) from graphnet.utilities.config.parsing import traverse_and_apply from graphnet.utilities.logging import Logger @@ -85,7 +85,13 @@ def parse_graph_definition(cfg: dict) -> GraphDefinition: return graph_definition -class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): +class Dataset( + Logger, + Configurable, + torch.utils.data.Dataset, + ABC, + metaclass=DatasetConfigSaverABCMeta, +): """Base Dataset class for reading from any intermediate file format.""" # Class method(s) @@ -188,7 +194,6 @@ def _resolve_graphnet_paths( .replace("${GRAPHNET}", GRAPHNET_ROOT_DIR) ) - @save_dataset_config def __init__( self, path: Union[str, List[str]], diff --git a/src/graphnet/models/coarsening.py b/src/graphnet/models/coarsening.py index 68eab50b9..d40f0c009 100644 --- a/src/graphnet/models/coarsening.py +++ b/src/graphnet/models/coarsening.py @@ -22,7 +22,6 @@ std_pool_x, ) from graphnet.models import Model -from graphnet.utilities.config import save_model_config # Utility method(s) from torch_geometric.utils import degree @@ -63,7 +62,6 @@ class Coarsening(Model): "sum": (sum_pool, sum_pool_x), } - @save_model_config def __init__( self, reduce: str = "avg", @@ -198,7 +196,6 @@ def _add_inc_dict(self, original: Data, pooled: Data) -> Data: class AttributeCoarsening(Coarsening): """Coarsen pulses based on specified attributes.""" - @save_model_config def __init__( self, attributes: List[str], diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index e1b1cc6ef..a7fb25f1d 100644 --- 
a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -8,13 +8,11 @@ from graphnet.models import Model from graphnet.utilities.decorators import final -from graphnet.utilities.config import save_model_config class Detector(Model): """Base class for all detector-specific read-ins in graphnet.""" - @save_model_config def __init__(self) -> None: """Construct `Detector`.""" # Base class constructor diff --git a/src/graphnet/models/gnn/convnet.py b/src/graphnet/models/gnn/convnet.py index 9c03c96f7..dcffd0c50 100644 --- a/src/graphnet/models/gnn/convnet.py +++ b/src/graphnet/models/gnn/convnet.py @@ -10,14 +10,12 @@ from torch_geometric.nn import TAGConv, global_add_pool, global_max_pool from torch_geometric.data import Data -from graphnet.utilities.config import save_model_config from graphnet.models.gnn.gnn import GNN class ConvNet(GNN): """ConvNet (convolutional network) model.""" - @save_model_config def __init__( self, nb_inputs: int, diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index 4e9e07b65..9ea93f9ce 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -7,7 +7,6 @@ from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum from graphnet.models.components.layers import DynEdgeConv -from graphnet.utilities.config import save_model_config from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily @@ -22,7 +21,6 @@ class DynEdge(GNN): """DynEdge (dynamical edge convolutional) model.""" - @save_model_config def __init__( self, nb_inputs: int, diff --git a/src/graphnet/models/gnn/dynedge_jinst.py b/src/graphnet/models/gnn/dynedge_jinst.py index 36a0f1303..23c630fa9 100644 --- a/src/graphnet/models/gnn/dynedge_jinst.py +++ b/src/graphnet/models/gnn/dynedge_jinst.py @@ -10,7 +10,6 @@ from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum from graphnet.models.components.layers import DynEdgeConv -from graphnet.utilities.config import save_model_config from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily @@ -18,7 +17,6 @@ class DynEdgeJINST(GNN): """DynEdge (dynamical edge convolutional) model used in [2209.03042].""" - @save_model_config def __init__( self, nb_inputs: int, diff --git a/src/graphnet/models/gnn/dynedge_kaggle_tito.py b/src/graphnet/models/gnn/dynedge_kaggle_tito.py index 2e07a4e72..d3196dd30 100644 --- a/src/graphnet/models/gnn/dynedge_kaggle_tito.py +++ b/src/graphnet/models/gnn/dynedge_kaggle_tito.py @@ -18,7 +18,6 @@ from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum from graphnet.models.components.layers import DynTrans -from graphnet.utilities.config import save_model_config from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily @@ -33,7 +32,6 @@ class DynEdgeTITO(GNN): """DynEdge (dynamical edge convolutional) model.""" - @save_model_config def __init__( self, nb_inputs: int, diff --git a/src/graphnet/models/gnn/gnn.py b/src/graphnet/models/gnn/gnn.py index de155cb4a..5fd933d84 100644 --- a/src/graphnet/models/gnn/gnn.py +++ b/src/graphnet/models/gnn/gnn.py @@ -6,13 +6,11 @@ from torch_geometric.data import Data from graphnet.models import Model -from graphnet.utilities.config import save_model_config class GNN(Model): """Base class for all core GNN models in graphnet.""" - @save_model_config def __init__(self, nb_inputs: int, nb_outputs: int) -> None: 
"""Construct `GNN`.""" # Base class constructor diff --git a/src/graphnet/models/graphs/edges/edges.py b/src/graphnet/models/graphs/edges/edges.py index 28507058b..cb9bf9112 100644 --- a/src/graphnet/models/graphs/edges/edges.py +++ b/src/graphnet/models/graphs/edges/edges.py @@ -7,7 +7,6 @@ from torch_geometric.nn import knn_graph, radius_graph from torch_geometric.data import Data -from graphnet.utilities.config import save_model_config from graphnet.models.utils import calculate_distance_matrix from graphnet.models import Model @@ -48,7 +47,6 @@ def _construct_edges(self, graph: Data) -> Data: class KNNEdges(EdgeDefinition): # pylint: disable=too-few-public-methods """Builds edges from the k-nearest neighbours.""" - @save_model_config def __init__( self, nb_nearest_neighbours: int, @@ -85,7 +83,6 @@ def _construct_edges(self, graph: Data) -> Data: class RadialEdges(EdgeDefinition): """Builds graph from a sphere of chosen radius centred at each node.""" - @save_model_config def __init__( self, radius: float, @@ -126,7 +123,6 @@ class EuclideanEdges(EdgeDefinition): # pylint: disable=too-few-public-methods See https://arxiv.org/pdf/1809.06166.pdf. """ - @save_model_config def __init__( self, sigma: float, diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 48394ab73..6f41f739d 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -11,8 +11,6 @@ from torch_geometric.data import Data import numpy as np -from graphnet.utilities.config import save_model_config - from graphnet.models.detector import Detector from .edges import EdgeDefinition from .nodes import NodeDefinition @@ -22,7 +20,6 @@ class GraphDefinition(Model): """An Abstract class to create graph definitions from.""" - @save_model_config def __init__( self, detector: Detector, diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index dc2ded022..1cae33a5d 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -3,7 +3,6 @@ from typing import List, Optional import torch -from graphnet.utilities.config import save_model_config from .graph_definition import GraphDefinition from graphnet.models.detector import Detector from graphnet.models.graphs.edges import EdgeDefinition, KNNEdges @@ -13,7 +12,6 @@ class KNNGraph(GraphDefinition): """A Graph representation where Edges are drawn to nearest neighbours.""" - @save_model_config def __init__( self, detector: Detector, diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index 6b3443e0c..ce539ee80 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -7,14 +7,12 @@ from torch_geometric.data import Data from graphnet.utilities.decorators import final -from graphnet.utilities.config import save_model_config from graphnet.models import Model class NodeDefinition(Model): # pylint: disable=too-few-public-methods """Base class for graph building.""" - @save_model_config def __init__(self) -> None: """Construct `Detector`.""" # Base class constructor diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index 193746919..7c34f0152 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -18,11 +18,17 @@ from torch_geometric.data import Data from graphnet.utilities.logging import Logger -from graphnet.utilities.config import Configurable, ModelConfig +from 
graphnet.utilities.config import ( + Configurable, + ModelConfig, + ModelConfigSaverABC, +) from graphnet.training.callbacks import ProgressBar -class Model(Logger, Configurable, LightningModule, ABC): +class Model( + Logger, Configurable, LightningModule, ABC, metaclass=ModelConfigSaverABC +): """Base class for all models in graphnet.""" @abstractmethod diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 1d439133f..0f4f6895b 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -10,7 +10,6 @@ from torch_geometric.data import Data import pandas as pd -from graphnet.utilities.config import save_model_config from graphnet.models.graphs import GraphDefinition from graphnet.models.gnn.gnn import GNN from graphnet.models.model import Model @@ -24,7 +23,6 @@ class StandardModel(Model): model (detector read-in, GNN architecture, and task-specific read-outs). """ - @save_model_config def __init__( self, *, diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index 094071a7c..0d7379f00 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -15,7 +15,6 @@ from graphnet.training.loss_functions import LossFunction # type: ignore[attr-defined] from graphnet.models import Model -from graphnet.utilities.config import save_model_config from graphnet.utilities.decorators import final @@ -39,7 +38,6 @@ def default_prediction_labels(self) -> List[str]: """Return default prediction labels.""" return self._default_prediction_labels - @save_model_config def __init__( self, *, @@ -264,7 +262,6 @@ def _validate_and_set_transforms( class IdentityTask(Task): """Identity, or trivial, task.""" - @save_model_config def __init__( self, nb_outputs: int, diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 740f0b912..624a5fa53 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -19,7 +19,6 @@ softplus, ) -from graphnet.utilities.config import save_model_config from graphnet.models.model import Model from graphnet.utilities.decorators import final @@ -27,7 +26,6 @@ class LossFunction(Model): """Base class for loss functions in `graphnet`.""" - @save_model_config def __init__(self, **kwargs: Any) -> None: """Construct `LossFunction`, saving model config.""" super().__init__(**kwargs) @@ -120,7 +118,6 @@ class CrossEntropyLoss(LossFunction): (0, num_classes - 1). 
""" - @save_model_config def __init__( self, options: Union[int, List[Any], Dict[Any, int]], diff --git a/src/graphnet/utilities/config/__init__.py b/src/graphnet/utilities/config/__init__.py index 426eae788..b53feb28e 100644 --- a/src/graphnet/utilities/config/__init__.py +++ b/src/graphnet/utilities/config/__init__.py @@ -1,6 +1,16 @@ """Modules for configuration files for use across `graphnet`.""" from .configurable import Configurable -from .dataset_config import DatasetConfig, DatasetConfigSaverMeta -from .model_config import ModelConfig, ModelConfigSaverMeta +from .dataset_config import ( + DatasetConfig, + DatasetConfigSaverMeta, + DatasetConfigSaverABCMeta, + DatasetConfigSaver, +) +from .model_config import ( + ModelConfig, + ModelConfigSaverMeta, + ModelConfigSaver, + ModelConfigSaverABC, +) from .training_config import TrainingConfig diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index b48459315..bb2dd1595 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -1,5 +1,5 @@ """Config classes for the `graphnet.data.dataset` module.""" - +from abc import ABCMeta from functools import wraps from typing import ( TYPE_CHECKING, @@ -208,3 +208,15 @@ def _replace_model_instance_with_config( arguments=dict(**cfg), ) return created_obj + + +class DatasetConfigSaverABCMeta(DatasetConfigSaverMeta, ABCMeta): + """Common interface between DatasetConfigSaver and ABC Metaclasses.""" + + pass + + +class DatasetConfigSaver(metaclass=DatasetConfigSaverMeta): + """Baseclass for DatasetConfig saving.""" + + pass diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index e811a1ab0..32978c4e1 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -1,4 +1,5 @@ """Config classes for the `graphnet.models` module.""" +from abc import ABCMeta from functools import wraps import inspect import re @@ -278,3 +279,15 @@ def _replace_model_instance_with_config( arguments=dict(**cfg), ) return created_obj + + +class ModelConfigSaverABC(ModelConfigSaverMeta, ABCMeta): + """Common interface between ModelConfigSaver and ABC Metaclasses.""" + + pass + + +class ModelConfigSaver(metaclass=ModelConfigSaverMeta): + """Base class for ModelConfig saving.""" + + pass From e4c06fdf085eb7f011b488650a533ca2773783dc Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Wed, 13 Sep 2023 11:27:53 +0200 Subject: [PATCH 03/26] Remove redundant config saving baseclasses. 
--- src/graphnet/utilities/config/dataset_config.py | 15 +++++---------- src/graphnet/utilities/config/model_config.py | 6 ------ 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index bb2dd1595..145776e06 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -189,9 +189,13 @@ def _replace_model_instance_with_config( ) -> Union[ModelConfig, Any]: """Replace `Model` instances in `obj` with their `ModelConfig`.""" from graphnet.models import Model + import torch if isinstance(obj, Model): return obj.config + + if isinstance(obj, torch.dtype): + return obj.__str__() else: return obj @@ -203,10 +207,7 @@ def _replace_model_instance_with_config( cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) # Store config in - created_obj._config = DatasetConfig( - class_name=str(created_obj.__class__.__name__), - arguments=dict(**cfg), - ) + created_obj._config = DatasetConfig(**cfg) return created_obj @@ -214,9 +215,3 @@ class DatasetConfigSaverABCMeta(DatasetConfigSaverMeta, ABCMeta): """Common interface between DatasetConfigSaver and ABC Metaclasses.""" pass - - -class DatasetConfigSaver(metaclass=DatasetConfigSaverMeta): - """Baseclass for DatasetConfig saving.""" - - pass diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index 32978c4e1..3fdc7d307 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -285,9 +285,3 @@ class ModelConfigSaverABC(ModelConfigSaverMeta, ABCMeta): """Common interface between ModelConfigSaver and ABC Metaclasses.""" pass - - -class ModelConfigSaver(metaclass=ModelConfigSaverMeta): - """Base class for ModelConfig saving.""" - - pass From 63a87153ed385575bb2cc9f54ac11eced35fcc6d Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Wed, 13 Sep 2023 11:29:35 +0200 Subject: [PATCH 04/26] Remove redundant config saving baseclasses. --- src/graphnet/utilities/config/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/graphnet/utilities/config/__init__.py b/src/graphnet/utilities/config/__init__.py index b53feb28e..15df77a24 100644 --- a/src/graphnet/utilities/config/__init__.py +++ b/src/graphnet/utilities/config/__init__.py @@ -5,12 +5,10 @@ DatasetConfig, DatasetConfigSaverMeta, DatasetConfigSaverABCMeta, - DatasetConfigSaver, ) from .model_config import ( ModelConfig, ModelConfigSaverMeta, - ModelConfigSaver, ModelConfigSaverABC, ) from .training_config import TrainingConfig From 9321927ee658225fb76b556849aac88555764879 Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Thu, 14 Sep 2023 11:17:31 +0200 Subject: [PATCH 05/26] Reintroduce save_model_config and save_dataset_config but added deprecation warning. To not break backwards compatibility. 
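For downstream code that still applies the decorators explicitly, the old
pattern keeps working but now warns. A minimal sketch (`LegacyModel` is a
made-up example class, not part of the library):

    from graphnet.models import Model
    from graphnet.utilities.config import save_model_config

    class LegacyModel(Model):
        @save_model_config  # emits a DeprecationWarning when the class is defined
        def __init__(self, nb_outputs: int = 1) -> None:
            super().__init__()
            self.nb_outputs = nb_outputs

Since `ModelConfigSaverMeta.__call__` assigns `created_obj._config` only after
`__init__` returns, the `_config` written by the decorator is overwritten by
the metaclass with an equivalent `ModelConfig`, so the decorator is redundant
but harmless.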
--- src/graphnet/utilities/config/__init__.py | 2 + .../utilities/config/dataset_config.py | 46 ++++++++++++++++++- src/graphnet/utilities/config/model_config.py | 45 +++++++++++++++++- 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/src/graphnet/utilities/config/__init__.py b/src/graphnet/utilities/config/__init__.py index 15df77a24..1520ca68d 100644 --- a/src/graphnet/utilities/config/__init__.py +++ b/src/graphnet/utilities/config/__init__.py @@ -5,10 +5,12 @@ DatasetConfig, DatasetConfigSaverMeta, DatasetConfigSaverABCMeta, + save_dataset_config, ) from .model_config import ( ModelConfig, ModelConfigSaverMeta, ModelConfigSaverABC, + save_model_config, ) from .training_config import TrainingConfig diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py index 145776e06..57739b667 100644 --- a/src/graphnet/utilities/config/dataset_config.py +++ b/src/graphnet/utilities/config/dataset_config.py @@ -1,4 +1,5 @@ """Config classes for the `graphnet.data.dataset` module.""" +import warnings from abc import ABCMeta from functools import wraps from typing import ( @@ -178,11 +179,54 @@ def _parse_torch(self, obj: Any) -> Any: return obj +def save_dataset_config(init_fn: Callable) -> Callable: + """Save the arguments to `__init__` functions as member `DatasetConfig`.""" + warnings.warn( + "Warning: `save_dataset_config` is deprecated. Config saving " + "is now done automatically, for all classes inheriting from Dataset", + DeprecationWarning, + ) + + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model + import torch + + if isinstance(obj, Model): + return obj.config + + if isinstance(obj, torch.dtype): + return obj.__str__() + + else: + return obj + + @wraps(init_fn) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + """Set `DatasetConfig` after calling `init_fn`.""" + # Call wrapped method + ret = init_fn(self, *args, **kwargs) + + # Get all argument values, including defaults + cfg = get_all_argument_values(init_fn, *args, **kwargs) + + # Handle nested `Model`s, etc. + cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) + # Add `DatasetConfig` as member variables + self._config = DatasetConfig(**cfg) + + return ret + + return wrapper + + class DatasetConfigSaverMeta(type): """Metaclass for `DatasetConfig` that saves the config after `__init__`.""" def __call__(cls: Any, *args: Any, **kwargs: Any) -> object: - """Catch object construction and save config after `__init__`.""" + """Catch object after construction and save config.""" def _replace_model_instance_with_config( obj: Union["Model", Any] diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index 3fdc7d307..23b4c9b58 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -3,6 +3,7 @@ from functools import wraps import inspect import re +import warnings from typing import ( TYPE_CHECKING, Any, @@ -249,6 +250,48 @@ def as_dict(self) -> Dict[str, Dict[str, Any]]: return {self.__class__.__name__: config_dict} +def save_model_config(init_fn: Callable) -> Callable: + """Save the arguments to `__init__` functions as a member `ModelConfig`.""" + warnings.warn( + "Warning: `save_model_config` is deprecated. 
Config saving is" + "now done automatically for all classes inheriting from Model", + DeprecationWarning, + ) + + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model + + if isinstance(obj, Model): + return obj.config + else: + return obj + + @wraps(init_fn) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + """Set `ModelConfig` after calling `init_fn`.""" + # Call wrapped method + ret = init_fn(self, *args, **kwargs) + + # Get all argument values, including defaults + cfg = get_all_argument_values(init_fn, *args, **kwargs) + + # Handle nested `Model`s, etc. + cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) + + # Add `ModelConfig` as member variables + self._config = ModelConfig( + class_name=str(self.__class__.__name__), + arguments=dict(**cfg), + ) + + return ret + + return wrapper + + class ModelConfigSaverMeta(type): """Metaclass for saving `ModelConfig` to `Model` instances.""" @@ -275,7 +318,7 @@ def _replace_model_instance_with_config( # Store config in created_obj._config = ModelConfig( - class_name=str(created_obj.__class__.__name__), + class_name=str(cls.__name__), arguments=dict(**cfg), ) return created_obj From 4af2872ad67abca7d99b50bfcee5c09892441827 Mon Sep 17 00:00:00 2001 From: amhermansen Date: Tue, 19 Sep 2023 18:00:15 +0200 Subject: [PATCH 06/26] Fixed typehints for make_(train_validation)_dataloader --- src/graphnet/training/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index f7d5249f9..1bb2d89e1 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -32,7 +32,7 @@ def collate_fn(graphs: List[Data]) -> Batch: def make_dataloader( db: str, pulsemaps: Union[str, List[str]], - graph_definition: Optional[GraphDefinition], + graph_definition: GraphDefinition, features: List[str], truth: List[str], *, @@ -92,7 +92,7 @@ def make_dataloader( # @TODO: Remove in favour of DataLoader{,.from_dataset_config} def make_train_validation_dataloader( db: str, - graph_definition: Optional[GraphDefinition], + graph_definition: GraphDefinition, selection: Optional[List[int]], pulsemaps: Union[str, List[str]], features: List[str], From 368b9a8c7c9771a03a8e5d78fa8b13851abc9757 Mon Sep 17 00:00:00 2001 From: amhermansen Date: Tue, 19 Sep 2023 18:30:17 +0200 Subject: [PATCH 07/26] Fixed typehints for make_(train_validation)_dataloader --- src/graphnet/training/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 1bb2d89e1..1befa9f77 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -22,7 +22,7 @@ def collate_fn(graphs: List[Data]) -> Batch: """Remove graphs with less than two DOM hits. - Should not occur in "production. + Should not occur in "productio"n. 
""" graphs = [g for g in graphs if g.n_pulses > 1] return Batch.from_data_list(graphs) From 396eb77a8dd5e90d00f5dfe1c369788c53a17b33 Mon Sep 17 00:00:00 2001 From: Andreas Michael Hermansen <97125645+AMHermansen@users.noreply.github.com> Date: Tue, 19 Sep 2023 18:36:09 +0200 Subject: [PATCH 08/26] Update utils.py fixed typo --- src/graphnet/training/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 1befa9f77..ec3b4c461 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -22,7 +22,7 @@ def collate_fn(graphs: List[Data]) -> Batch: """Remove graphs with less than two DOM hits. - Should not occur in "productio"n. + Should not occur in "production". """ graphs = [g for g in graphs if g.n_pulses > 1] return Batch.from_data_list(graphs) From b846b4067fccbbdc0a3b236fbf750cecf3f76e0a Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 09:39:13 +0200 Subject: [PATCH 09/26] fix example 02-02 --- .../02_data/02_plot_feature_distributions.py | 38 +++---------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/examples/02_data/02_plot_feature_distributions.py b/examples/02_data/02_plot_feature_distributions.py index 88ffef576..b46be0623 100644 --- a/examples/02_data/02_plot_feature_distributions.py +++ b/examples/02_data/02_plot_feature_distributions.py @@ -1,4 +1,4 @@ -"""Example of plotting feature distributions from SQLite database.""" +"""Example of visualization of input data from a configured Dataset.""" import os.path @@ -8,8 +8,6 @@ from graphnet.constants import CONFIG_DIR from graphnet.data.dataset import Dataset -from graphnet.models.detector.icecube import IceCubeDeepCore -from graphnet.models.graph_builders import KNNGraphBuilder from graphnet.utilities.logging import Logger from graphnet.utilities.argparse import ArgumentParser @@ -27,46 +25,20 @@ def main() -> None: assert isinstance(dataset, Dataset) features = dataset._features[1:] - # Building model - detector = IceCubeDeepCore( - graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8), - ) - # Get feature matrix - x_original_list = [] x_preprocessed_list = [] for batch in tqdm(dataset, colour="green"): - x_original_list.append(batch.x.numpy()) - x_preprocessed_list.append(detector(batch).x.numpy()) + x_preprocessed_list.append(batch.x.numpy()) - x_original = np.concatenate(x_original_list, axis=0) x_preprocessed = np.concatenate(x_preprocessed_list, axis=0) - - logger.info(f"Number of NaNs: {np.sum(np.isnan(x_original))}") - logger.info(f"Number of infs: {np.sum(np.isinf(x_original))}") + logger.info(f"Number of NaNs: {np.sum(np.isnan(x_preprocessed))}") + logger.info(f"Number of infs: {np.sum(np.isinf(x_preprocessed))}") # Plot feature distributions - nb_features_original = x_original.shape[1] nb_features_preprocessed = x_preprocessed.shape[1] dim = int(np.ceil(np.sqrt(nb_features_preprocessed))) axis_size = 4 - bins = 100 - - # -- Original - fig, axes = plt.subplots( - dim, dim, figsize=(dim * axis_size, dim * axis_size) - ) - for ix, ax in enumerate(axes.ravel()[:nb_features_original]): - ax.hist(x_original[:, ix], bins=bins) - ax.set_xlabel( - f"x{ix}: {features[ix] if ix < len(features) else 'N/A'}" - ) - ax.set_yscale("log") - - fig.tight_layout - figure_name_original = "feature_distribution_original.png" - fig.savefig(figure_name_original) - logger.info(f"Figure written to {figure_name_original}") + bins = 50 # -- Preprocessed fig, axes = plt.subplots( From 
ae226e0b0fde149b0fad3539f940302a160ecd08 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 10:13:05 +0200 Subject: [PATCH 10/26] default arguments, fix 02-01 --- examples/02_data/01_read_dataset.py | 21 ++++++++++++------- .../models/graphs/graph_definition.py | 6 +++--- src/graphnet/models/graphs/graphs.py | 4 ++-- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/02_data/01_read_dataset.py b/examples/02_data/01_read_dataset.py index 302529050..be913e190 100644 --- a/examples/02_data/01_read_dataset.py +++ b/examples/02_data/01_read_dataset.py @@ -17,7 +17,10 @@ from graphnet.data.dataset import ParquetDataset from graphnet.utilities.argparse import ArgumentParser from graphnet.utilities.logging import Logger - +from graphnet.models.graphs import KNNGraph +from graphnet.models.detector.icecube import ( + IceCubeDeepCore, +) DATASET_CLASS = { "sqlite": SQLiteDataset, @@ -44,6 +47,9 @@ def main(backend: str) -> None: num_workers = 30 wait_time = 0.00 # sec. + # Define graph representation + graph_definition = KNNGraph(detector=IceCubeDeepCore()) + for table in [pulsemap, truth_table]: # Get column names from backend if backend == "sqlite": @@ -62,15 +68,16 @@ def main(backend: str) -> None: # Common variables dataset = DATASET_CLASS[backend]( - path, - pulsemap, - features, - truth, + path=path, + pulsemaps=pulsemap, + features=features, + truth=truth, truth_table=truth_table, + graph_definition=graph_definition, ) assert isinstance(dataset, Dataset) - logger.info(dataset[1]) + logger.info(str(dataset[1])) logger.info(dataset[1].x) if backend == "sqlite": assert isinstance(dataset, SQLiteDataset) @@ -92,7 +99,7 @@ def main(backend: str) -> None: for batch in tqdm(dataloader, unit=" batches", colour="green"): time.sleep(wait_time) - logger.info(batch) + logger.info(str(batch)) logger.info(batch.size()) logger.info(batch.num_graphs) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 6f41f739d..0a87a301e 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -13,7 +13,7 @@ from graphnet.models.detector import Detector from .edges import EdgeDefinition -from .nodes import NodeDefinition +from .nodes import NodeDefinition, NodesAsPulses from graphnet.models import Model @@ -23,7 +23,7 @@ class GraphDefinition(Model): def __init__( self, detector: Detector, - node_definition: NodeDefinition, + node_definition: NodeDefinition = NodesAsPulses(), edge_definition: Optional[EdgeDefinition] = None, node_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, @@ -39,7 +39,7 @@ def __init__( Args: detector: The corresponding ´Detector´ representing the data. - node_definition: Definition of nodes. + node_definition: Definition of nodes. Defaults to NodesAsPulses. edge_definition: Definition of edges. Defaults to None. node_feature_names: Names of node feature columns. Defaults to None dtype: data type used for node features. e.g. 
´torch.float´ diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index 1cae33a5d..a48b33a0d 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -6,7 +6,7 @@ from .graph_definition import GraphDefinition from graphnet.models.detector import Detector from graphnet.models.graphs.edges import EdgeDefinition, KNNEdges -from graphnet.models.graphs.nodes import NodeDefinition +from graphnet.models.graphs.nodes import NodeDefinition, NodesAsPulses class KNNGraph(GraphDefinition): @@ -15,7 +15,7 @@ class KNNGraph(GraphDefinition): def __init__( self, detector: Detector, - node_definition: NodeDefinition, + node_definition: NodeDefinition = NodesAsPulses(), node_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, nb_nearest_neighbours: int = 8, From 8e04af206bb6a287fb8ca28dfd377e3de099a386 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 10:46:47 +0200 Subject: [PATCH 11/26] tito_example update --- .../04_training/04_train_tito_model_without_configs.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/examples/04_training/04_train_tito_model_without_configs.py b/examples/04_training/04_train_tito_model_without_configs.py index 858aa45f4..40ab926e3 100644 --- a/examples/04_training/04_train_tito_model_without_configs.py +++ b/examples/04_training/04_train_tito_model_without_configs.py @@ -28,7 +28,6 @@ # Constants features = FEATURES.PROMETHEUS truth = TRUTH.PROMETHEUS -DYNTRANS_LAYER_SIZES = [(256, 256), (256, 256), (256, 256)] def main( @@ -76,12 +75,7 @@ def main( }, } - graph_definition = KNNGraph( - detector=Prometheus(), - node_definition=NodesAsPulses(), - nb_nearest_neighbours=8, - node_feature_names=features, - ) + graph_definition = KNNGraph(detector=Prometheus()) archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_tito_model") run_name = "dynedgeTITO_{}_example".format(config["target"]) if wandb: @@ -115,7 +109,6 @@ def main( gnn = DynEdgeTITO( nb_inputs=graph_definition.nb_outputs, global_pooling_schemes=["max"], - dyntrans_layer_sizes=DYNTRANS_LAYER_SIZES, ) task = DirectionReconstructionWithKappa( hidden_size=gnn.nb_outputs, From 15a14f7071d8e1b5927c475be4d67999cbef01fa Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 11:33:23 +0200 Subject: [PATCH 12/26] Polish examples --- .../04_training/02_train_model_without_configs.py | 15 ++++++++------- .../04_train_tito_model_without_configs.py | 8 +++++++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/examples/04_training/02_train_model_without_configs.py b/examples/04_training/02_train_model_without_configs.py index 6d9c5746e..6a8bf2f15 100644 --- a/examples/04_training/02_train_model_without_configs.py +++ b/examples/04_training/02_train_model_without_configs.py @@ -79,12 +79,7 @@ def main( wandb_logger.experiment.config.update(config) # Define graph representation - graph_definition = KNNGraph( - detector=Prometheus(), - node_definition=NodesAsPulses(), - nb_nearest_neighbours=8, - node_feature_names=features, - ) + graph_definition = KNNGraph(detector=Prometheus()) ( training_dataloader, @@ -166,10 +161,16 @@ def main( logger.info(f"Writing results to {path}") os.makedirs(path, exist_ok=True) + # Save results as .csv results.to_csv(f"{path}/results.csv") - model.save_state_dict(f"{path}/state_dict.pth") + + # Save full model (including weights) to .pth file - Not version proof model.save(f"{path}/model.pth") + # Save model config and state dict - 
Version safe save method. + model.save_state_dict(f"{path}/state_dict.pth") + model.save_config(f"{path}/model_config.yml") + if __name__ == "__main__": diff --git a/examples/04_training/04_train_tito_model_without_configs.py b/examples/04_training/04_train_tito_model_without_configs.py index 40ab926e3..60b19d392 100644 --- a/examples/04_training/04_train_tito_model_without_configs.py +++ b/examples/04_training/04_train_tito_model_without_configs.py @@ -175,10 +175,16 @@ def main( logger.info(f"Writing results to {path}") os.makedirs(path, exist_ok=True) + # Save results as .csv results.to_csv(f"{path}/results.csv") - model.save_state_dict(f"{path}/state_dict.pth") + + # Save full model (including weights) to .pth file - Not version proof model.save(f"{path}/model.pth") + # Save model config and state dict - Version safe save method. + model.save_state_dict(f"{path}/state_dict.pth") + model.save_config(f"{path}/model_config.yml") + if __name__ == "__main__": From a11014c220691f8012996198d90fd0d5d474c1fa Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 11:36:06 +0200 Subject: [PATCH 13/26] delete shell script example --- .../04_training/03_train_multiple_models.sh | 62 ------------------- .../04_train_tito_model_without_configs.py | 1 - 2 files changed, 63 deletions(-) delete mode 100644 examples/04_training/03_train_multiple_models.sh diff --git a/examples/04_training/03_train_multiple_models.sh b/examples/04_training/03_train_multiple_models.sh deleted file mode 100644 index 931b574ca..000000000 --- a/examples/04_training/03_train_multiple_models.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/bin/bash - -#### This script enables the user to run multiple trainings in sequence on the same database but for different model configs. -# To execute this file, copy the file path and write in the terminal; $ bash - - -# execution of bash file in same directory as the script -bash_directory=$(dirname -- "$(readlink -f "${BASH_SOURCE}")") - -## Global; applies to all models -# path to dataset configuration file in the GraphNeT directory -dataset_config=$(realpath "$bash_directory/../../configs/datasets/training_example_data_sqlite.yml") -# what GPU to use; more information can be gained with the module nvitop -gpus=0 -# the maximum number of epochs; if used, this greatly affect learning rate scheduling -max_epochs=5 -# early stopping threshold -early_stopping_patience=5 -# events in a batch -batch_size=16 -# number of CPUs to use -num_workers=2 - -## Model dependent; applies to each model in sequence -# path to model files in the GraphNeT directory -model_directory=$(realpath "$bash_directory/../../configs/models") -# list of model configurations to train -declare -a model_configs=( - "${model_directory}/example_direction_reconstruction_model.yml" - "${model_directory}/example_energy_reconstruction_model.yml" - "${model_directory}/example_vertex_position_reconstruction_model.yml" -) - -# suffix ending on the created directory -declare -a suffixs=( - "direction" - "energy" - "position" -) - -# prediction name outputs per model -declare -a prediction_names=( - "zenith_pred zenith_kappa_pred azimuth_pred azimuth_kappa_pred" - "energy_pred" - "position_x_pred position_y_pred position_z_pred" -) - -for i in "${!model_configs[@]}"; do - echo "training iteration ${i} on ${model_configs[$i]} with output variables ${prediction_names[i][@]}" - python ${bash_directory}/01_train_model.py \ - --dataset-config ${dataset_config} \ - --model-config ${model_configs[$i]} \ - --gpus ${gpus} \ - --max-epochs 
${max_epochs} \ - --early-stopping-patience ${early_stopping_patience} \ - --batch-size ${batch_size} \ - --num-workers ${num_workers} \ - --prediction-names ${prediction_names[i][@]} \ - --suffix ${suffixs[i]} - wait -done -echo "all trainings are done." \ No newline at end of file diff --git a/examples/04_training/04_train_tito_model_without_configs.py b/examples/04_training/04_train_tito_model_without_configs.py index 60b19d392..ee3d89760 100644 --- a/examples/04_training/04_train_tito_model_without_configs.py +++ b/examples/04_training/04_train_tito_model_without_configs.py @@ -14,7 +14,6 @@ from graphnet.models.detector.prometheus import Prometheus from graphnet.models.gnn import DynEdgeTITO from graphnet.models.graphs import KNNGraph -from graphnet.models.graphs.nodes import NodesAsPulses from graphnet.models.task.reconstruction import ( DirectionReconstructionWithKappa, ) From c840b44180ede370241c5727f48c2920e5a5342b Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 11:48:09 +0200 Subject: [PATCH 14/26] rename examples, update readme.md --- ...without_configs.py => 01_train_dynedge.py} | 0 ...hout_configs.py => 02_train_tito_model.py} | 0 ...del.py => 03_train_dynedge_from_config.py} | 7 ++++--- ... 04_train_multiclassifier_from_configs.py} | 2 +- examples/04_training/README.md | 20 +++++++++---------- 5 files changed, 15 insertions(+), 14 deletions(-) rename examples/04_training/{02_train_model_without_configs.py => 01_train_dynedge.py} (100%) rename examples/04_training/{04_train_tito_model_without_configs.py => 02_train_tito_model.py} (100%) rename examples/04_training/{01_train_model.py => 03_train_dynedge_from_config.py} (94%) rename examples/04_training/{03_train_classification_model.py => 04_train_multiclassifier_from_configs.py} (98%) diff --git a/examples/04_training/02_train_model_without_configs.py b/examples/04_training/01_train_dynedge.py similarity index 100% rename from examples/04_training/02_train_model_without_configs.py rename to examples/04_training/01_train_dynedge.py diff --git a/examples/04_training/04_train_tito_model_without_configs.py b/examples/04_training/02_train_tito_model.py similarity index 100% rename from examples/04_training/04_train_tito_model_without_configs.py rename to examples/04_training/02_train_tito_model.py diff --git a/examples/04_training/01_train_model.py b/examples/04_training/03_train_dynedge_from_config.py similarity index 94% rename from examples/04_training/01_train_model.py rename to examples/04_training/03_train_dynedge_from_config.py index 0250f6602..9a6df73dd 100644 --- a/examples/04_training/01_train_model.py +++ b/examples/04_training/03_train_dynedge_from_config.py @@ -1,4 +1,4 @@ -"""Simplified example of training Model.""" +"""Simplified example of training DynEdge from pre-defined config files.""" from typing import List, Optional import os @@ -46,7 +46,7 @@ def main( log_model=True, ) - # Build model + # Build model from pre-defined config file made from Model.save_config model_config = ModelConfig.load(model_config_path) model: StandardModel = StandardModel.from_config(model_config, trust=True) @@ -69,7 +69,8 @@ def main( archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model") run_name = "dynedge_{}_example".format("_".join(config.target)) - # Construct dataloaders + # Construct dataloaders from pre-defined dataset config files. + # i.e. 
from Dataset.save_config
     dataset_config = DatasetConfig.load(dataset_config_path)
     dataloaders = DataLoader.from_dataset_config(
         dataset_config,
diff --git a/examples/04_training/03_train_classification_model.py b/examples/04_training/04_train_multiclassifier_from_configs.py
similarity index 98%
rename from examples/04_training/03_train_classification_model.py
rename to examples/04_training/04_train_multiclassifier_from_configs.py
index 65b1acf2a..6937b01b1 100644
--- a/examples/04_training/03_train_classification_model.py
+++ b/examples/04_training/04_train_multiclassifier_from_configs.py
@@ -1,4 +1,4 @@
-"""Simplified example of multi-class classification training Model."""
+"""Multi-class classification using DynEdge from pre-defined config files."""

 import os
 from typing import List, Optional, Dict, Any
diff --git a/examples/04_training/README.md b/examples/04_training/README.md
index e60aef3ae..1849515be 100644
--- a/examples/04_training/README.md
+++ b/examples/04_training/README.md
@@ -2,44 +2,44 @@

 This subfolder contains two main training scripts:

-**`01_train_dynedge.py`** Shows how to train a GNN on neutrino telescope data **without configuration files,** i.e., by programmatically constructing the dataset and model used. This is good for debugging and experimenting with different dataset settings and model configurations, as it is easier to build the model using the API than by writing configuration files from scratch. **This is our recommended way of getting started with the library**. For instance, try running:

```bash
# Show the CLI
-(graphnet) $ python examples/04_training/01_train_dynedge.py --help

# Train energy regression model
-(graphnet) $ python examples/04_training/01_train_dynedge.py

# Same as above, as this is the default model config.
(graphnet) $ python examples/04_training/01_train_dynedge.py \
    --model-config configs/models/example_energy_reconstruction_model.yml

# Train using a single GPU
-(graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0

# Train using multiple GPUs
-(graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0 1

# Train a vertex position reconstruction model
-(graphnet) $ python examples/04_training/01_train_dynedge.py \
    --model-config configs/models/example_vertex_position_reconstruction_model.yml

# Trains a direction (zenith, azimuth) reconstruction model. Note that the
# chosen `Task` in the model config file also returns estimated "kappa" values,
# i.e. inverse variance, for each predicted feature, meaning that we need to
# manually specify the names of these.
-(graphnet) $ python examples/04_training/01_train_model.py --gpus 0 \
+(graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0 \
    --model-config configs/models/example_direction_reconstruction_model.yml \
    --prediction-names zenith_pred zenith_kappa_pred azimuth_pred azimuth_kappa_pred
```

**`03_train_dynedge_from_config.py`** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard dataset and models, as it is easier to read and share than doing so in pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running:

```bash
# Show the CLI
(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py --help

# Train energy regression model
(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py
```

From 1fbf534c81bddf146ca4da2fe61d6bee94db50c2 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Fri, 22 Sep 2023 13:08:43 +0200
Subject: [PATCH 15/26] Move perturbations to graph_definition

---
 .../models/graphs/graph_definition.py         | 48 ++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
index 0a87a301e..8ad95ce58 100644
--- a/src/graphnet/models/graphs/graph_definition.py
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -6,10 +6,11 @@
 """

-from typing import Any, List, Optional, Dict, Callable
+from typing import Any, List, Optional, Dict, Callable, Union
 import torch
 from torch_geometric.data import Data
 import numpy as np
+from numpy.random import default_rng, Generator

 from graphnet.models.detector import Detector
 from .edges import EdgeDefinition
@@ -27,6 +28,8 @@ def __init__(
         edge_definition: Optional[EdgeDefinition] = None,
         node_feature_names: Optional[List[str]] = None,
         dtype: Optional[torch.dtype] = torch.float,
+        perturbation_dict: Optional[Dict[str, float]] = None,
+        seed: Optional[Union[int, Generator]] = None,
     ):
         """Construct ´GraphDefinition´. The ´detector´ holds.
@@ -43,6 +46,12 @@
         Args:
             detector: The corresponding ´Detector´ representing the data.
             node_definition: Definition of nodes. Defaults to NodesAsPulses.
             edge_definition: Definition of edges. Defaults to None.
             node_feature_names: Names of node feature columns. Defaults to None
             dtype: data type used for node features. e.g. ´torch.float´
+            perturbation_dict: Dictionary mapping a feature name to a standard
""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -51,6 +60,8 @@ def __init__( self._detector = detector self._edge_definition = edge_definition self._node_definition = node_definition + self._perturbation_dict = perturbation_dict + if node_feature_names is None: # Assume all features in Detector is used. node_feature_names = list(self._detector.feature_map().keys()) # type: ignore @@ -66,6 +77,24 @@ def __init__( self.nb_inputs = len(self._node_feature_names) self.nb_outputs = self._node_definition.nb_outputs + # Set perturbation_cols if needed + if isinstance(self._perturbation_dict, dict): + self._perturbation_cols = [ + self._node_feature_names.index(key) + for key in self._perturbation_dict.keys() + ] + if seed is not None: + if isinstance(seed, int): + self.rng = default_rng(seed) + elif isinstance(seed, Generator): + self.rng = seed + else: + raise ValueError( + "Invalid seed. Must be an int or a numpy Generator." + ) + else: + self.rng = default_rng() + def forward( # type: ignore self, node_features: np.ndarray, @@ -97,6 +126,9 @@ def forward( # type: ignore node_features=node_features, node_feature_names=node_feature_names ) + # Gaussian perturbation of each column if perturbation dict is given + node_features = self._perturb_input(node_features) + # Transform to pytorch tensor node_features = torch.tensor(node_features, dtype=self.dtype) @@ -164,6 +196,20 @@ def _validate_input( node_feature_names[idx] == self._node_feature_names[idx] ), f""" Order of node features in data are not the same as expected. Got {node_feature_names} vs. {self._node_feature_names}""" + def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: + if isinstance(self._perturbation_dict, dict): + self.warning_once( + f"""Will randomly perturb {list(self._perturbation_dict.keys())} using standard diviations {self._perturbation_dict.values}""" + ) + perturbed_features = self.rng.normal( + loc=node_features[:, self._perturbation_cols], + scale=np.array( + list(self._perturbation_dict.values()), dtype=np.float + ), + ) + node_features[:, self._perturbation_cols] = perturbed_features + return node_features + def _add_loss_weights( self, graph: Data, From d4e166a19724bb2b885d234248b87a4eae791184 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 13:58:23 +0200 Subject: [PATCH 16/26] minor adjustments, unit test --- src/graphnet/models/graphs/graph_definition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 8ad95ce58..01197fc46 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -199,7 +199,7 @@ def _validate_input( def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: if isinstance(self._perturbation_dict, dict): self.warning_once( - f"""Will randomly perturb {list(self._perturbation_dict.keys())} using standard diviations {self._perturbation_dict.values}""" + f"""Will randomly perturb {list(self._perturbation_dict.keys())} using stds {self._perturbation_dict.values}""" ) perturbed_features = self.rng.normal( loc=node_features[:, self._perturbation_cols], From e93f5147d1568484799c36c125f5f4f3434e8f50 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 14:00:33 +0200 Subject: [PATCH 17/26] Unit tests --- .../models/graphs/graph_definition.py | 2 +- src/graphnet/models/graphs/graphs.py | 13 ++++- 
tests/models/test_graph_definition.py | 53 +++++++++++++++++++ 3 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 tests/models/test_graph_definition.py diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 01197fc46..4741be7eb 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -199,7 +199,7 @@ def _validate_input( def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: if isinstance(self._perturbation_dict, dict): self.warning_once( - f"""Will randomly perturb {list(self._perturbation_dict.keys())} using stds {self._perturbation_dict.values}""" + f"""Will randomly perturb {list(self._perturbation_dict.keys())} using stds {self._perturbation_dict.values()}""" ) perturbed_features = self.rng.normal( loc=node_features[:, self._perturbation_cols], diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index a48b33a0d..4ae53037a 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -1,7 +1,8 @@ """A module containing different graph representations in GraphNeT.""" -from typing import List, Optional +from typing import List, Optional, Dict, Union import torch +from numpy.random import Generator from .graph_definition import GraphDefinition from graphnet.models.detector import Detector @@ -18,6 +19,8 @@ def __init__( node_definition: NodeDefinition = NodesAsPulses(), node_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, nb_nearest_neighbours: int = 8, columns: List[int] = [0, 1, 2], ) -> None: @@ -28,6 +31,12 @@ def __init__( node_definition: Definition of nodes in the graph. node_feature_names: Name of node features. dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. nb_nearest_neighbours: Number of edges for each node. Defaults to 8. columns: node feature columns used for distance calculation . Defaults to [0, 1, 2]. 
@@ -42,4 +51,6 @@ def __init__(
             ),
             dtype=dtype,
             node_feature_names=node_feature_names,
+            perturbation_dict=perturbation_dict,
+            seed=seed,
         )
diff --git a/tests/models/test_graph_definition.py b/tests/models/test_graph_definition.py
new file mode 100644
index 000000000..bf16d7853
--- /dev/null
+++ b/tests/models/test_graph_definition.py
@@ -0,0 +1,53 @@
+"""Unit tests for GraphDefinition."""
+
+from graphnet.models.graphs import KNNGraph
+from graphnet.models.detector.prometheus import Prometheus
+from graphnet.data.constants import FEATURES
+
+import numpy as np
+from copy import deepcopy
+import torch
+
+
+def test_graph_definition() -> None:
+    """Tests the forward pass of GraphDefinition."""
+    # Test configuration
+    features = FEATURES.PROMETHEUS
+    perturbation_dict = {
+        "sensor_pos_x": 1.4,
+        "sensor_pos_y": 2.2,
+        "sensor_pos_z": 3.7,
+        "t": 1.2,
+    }
+    mock_data = np.array([[1, 5, 2, 3], [2, 9, 6, 2]])
+    seed = 42
+    n_reps = 5
+
+    graph_definition = KNNGraph(
+        detector=Prometheus(), perturbation_dict=perturbation_dict, seed=seed
+    )
+    original_output = graph_definition(
+        node_features=deepcopy(mock_data), node_feature_names=features
+    )
+
+    for _ in range(n_reps):
+        graph_definition_perturbed = KNNGraph(
+            detector=Prometheus(), perturbation_dict=perturbation_dict
+        )
+
+        graph_definition = KNNGraph(
+            detector=Prometheus(),
+            perturbation_dict=perturbation_dict,
+            seed=seed,
+        )
+
+        data = graph_definition(
+            node_features=deepcopy(mock_data), node_feature_names=features
+        )
+
+        perturbed_data = graph_definition_perturbed(
+            node_features=deepcopy(mock_data), node_feature_names=features
+        )
+
+        assert not torch.equal(data.x, perturbed_data.x)  # should not be equal.
+        assert torch.equal(data.x, original_output.x)  # should be equal.

From b567581a5ee250df3f6ba2fe326def165ec5a95e Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Fri, 22 Sep 2023 14:01:14 +0200
Subject: [PATCH 18/26] delete perturbedsqlitedataset

---
 .../sqlite/sqlite_dataset_perturbed.py        | 152 ------------------
 1 file changed, 152 deletions(-)
 delete mode 100644 src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py

diff --git a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py
deleted file mode 100644
index b951e6916..000000000
--- a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py
+++ /dev/null
@@ -1,152 +0,0 @@
-"""`Dataset` class(es) for reading perturbed data from SQLite databases."""
-
-from typing import Dict, List, Optional, Tuple, Union
-
-import numpy as np
-from numpy.random import default_rng, Generator
-import torch
-from torch_geometric.data import Data
-
-from .sqlite_dataset import SQLiteDataset
-
-
-class SQLiteDatasetPerturbed(SQLiteDataset):
-    """Pytorch dataset for reading perturbed data from SQLite databases.
-
-    This including a pre-processing step, where the input data is randomly
-    perturbed according to given per-feature "noise" levels. This is intended
-    to test the stability of a trained model under small changes to the input
-    parameters.
- """ - - def __init__( - self, - path: Union[str, List[str]], - pulsemaps: Union[str, List[str]], - features: List[str], - truth: List[str], - *, - perturbation_dict: Dict[str, float], - node_truth: Optional[List[str]] = None, - index_column: str = "event_no", - truth_table: str = "truth", - node_truth_table: Optional[str] = None, - string_selection: Optional[List[int]] = None, - selection: Optional[List[int]] = None, - dtype: torch.dtype = torch.float32, - loss_weight_table: Optional[str] = None, - loss_weight_column: Optional[str] = None, - loss_weight_default_value: Optional[float] = None, - seed: Optional[Union[int, Generator]] = None, - ): - """Construct SQLiteDatasetPerturbed. - - Args: - path: Path to the file(s) from which this `Dataset` should read. - pulsemaps: Name(s) of the pulse map series that should be used to - construct the nodes on the individual graph objects, and their - features. Multiple pulse series maps can be used, e.g., when - different DOM types are stored in different maps. - features: List of columns in the input files that should be used as - node features on the graph objects. - truth: List of event-level columns in the input files that should - be used added as attributes on the graph objects. - perturbation_dict (Dict[str, float]): Dictionary mapping a feature - name to a standard deviation according to which the values for - this feature should be randomly perturbed. - node_truth: List of node-level columns in the input files that - should be used added as attributes on the graph objects. - index_column: Name of the column in the input files that contains - unique indicies to identify and map events across tables. - truth_table: Name of the table containing event-level truth - information. - node_truth_table: Name of the table containing node-level truth - information. - string_selection: Subset of strings for which data should be read - and used to construct graph objects. Defaults to None, meaning - all strings for which data exists are used. - selection: List of indicies (in `index_column`) of the events in - the input files that should be read. Defaults to None, meaning - that all events in the input files are read. - dtype: Type of the feature tensor on the graph objects returned. - loss_weight_table: Name of the table containing per-event loss - weights. - loss_weight_column: Name of the column in `loss_weight_table` - containing per-event loss weights. This is also the name of the - corresponding attribute assigned to the graph object. - loss_weight_default_value: Default per-event loss weight. - NOTE: This default value is only applied when - `loss_weight_table` and `loss_weight_column` are specified, and - in this case to events with no value in the corresponding - table/column. That is, if no per-event loss weight table/column - is provided, this value is ignored. Defaults to None. - seed: Optional seed for random number generation. Defaults to None. 
- """ - # Base class constructor - super().__init__( - path=path, - pulsemaps=pulsemaps, - features=features, - truth=truth, - node_truth=node_truth, - index_column=index_column, - truth_table=truth_table, - node_truth_table=node_truth_table, - string_selection=string_selection, - selection=selection, - dtype=dtype, - loss_weight_table=loss_weight_table, - loss_weight_column=loss_weight_column, - loss_weight_default_value=loss_weight_default_value, - ) - - # Custom member variables - assert isinstance(perturbation_dict, dict) - assert len(set(perturbation_dict.keys())) == len( - perturbation_dict.keys() - ) - self._perturbation_dict = perturbation_dict - - self._perturbation_cols = [ - self._features.index(key) for key in self._perturbation_dict.keys() - ] - - if seed is not None: - if isinstance(seed, int): - self.rng = default_rng(seed) - elif isinstance(seed, Generator): - self.rng = seed - else: - raise ValueError( - "Invalid seed. Must be an int or a numpy Generator." - ) - else: - self.rng = default_rng() - - def __getitem__(self, sequential_index: int) -> Data: - """Return graph `Data` object at `index`.""" - if not (0 <= sequential_index < len(self)): - raise IndexError( - f"Index {sequential_index} not in range [0, {len(self) - 1}]" - ) - features, truth, node_truth, loss_weight = self._query( - sequential_index - ) - perturbed_features = self._perturb_features(features) - graph = self._create_graph( - perturbed_features, truth, node_truth, loss_weight - ) - return graph - - def _perturb_features( - self, features: List[Tuple[float, ...]] - ) -> List[Tuple[float, ...]]: - features_array = np.array(features) - perturbed_features = self.rng.normal( - loc=features_array[:, self._perturbation_cols], - scale=np.array( - list(self._perturbation_dict.values()), dtype=np.float - ), - ) - features_array[:, self._perturbation_cols] = perturbed_features - return features_array.tolist() From 8c54c77c5d77516734388631ca861738a22ed23b Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 14:13:06 +0200 Subject: [PATCH 19/26] remove old import statements --- src/graphnet/data/dataset/__init__.py | 1 - src/graphnet/data/dataset/sqlite/__init__.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/graphnet/data/dataset/__init__.py b/src/graphnet/data/dataset/__init__.py index 3ccdd9642..f6eafee94 100644 --- a/src/graphnet/data/dataset/__init__.py +++ b/src/graphnet/data/dataset/__init__.py @@ -7,7 +7,6 @@ from .dataset import EnsembleDataset, Dataset, ColumnMissingException from .parquet.parquet_dataset import ParquetDataset from .sqlite.sqlite_dataset import SQLiteDataset - from .sqlite.sqlite_dataset_perturbed import SQLiteDatasetPerturbed torch.multiprocessing.set_sharing_strategy("file_system") diff --git a/src/graphnet/data/dataset/sqlite/__init__.py b/src/graphnet/data/dataset/sqlite/__init__.py index 84d67a921..c44d66184 100644 --- a/src/graphnet/data/dataset/sqlite/__init__.py +++ b/src/graphnet/data/dataset/sqlite/__init__.py @@ -3,6 +3,5 @@ if has_torch_package(): from .sqlite_dataset import SQLiteDataset - from .sqlite_dataset_perturbed import SQLiteDatasetPerturbed del has_torch_package From 6f014fa6f650bd1e98f11fee69b56c195ee509ba Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 19:28:54 +0200 Subject: [PATCH 20/26] replace np.float with float --- src/graphnet/models/graphs/graph_definition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/graph_definition.py 
b/src/graphnet/models/graphs/graph_definition.py index 4741be7eb..4bac67cd5 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -204,7 +204,7 @@ def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: perturbed_features = self.rng.normal( loc=node_features[:, self._perturbation_cols], scale=np.array( - list(self._perturbation_dict.values()), dtype=np.float + list(self._perturbation_dict.values()), dtype=float ), ) node_features[:, self._perturbation_cols] = perturbed_features From 726e65367a602aeabd1f1b828b9e74567f30f990 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 19:55:31 +0200 Subject: [PATCH 21/26] shorten warning --- src/graphnet/models/graphs/graph_definition.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 4bac67cd5..35eefd5f3 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -199,7 +199,9 @@ def _validate_input( def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: if isinstance(self._perturbation_dict, dict): self.warning_once( - f"""Will randomly perturb {list(self._perturbation_dict.keys())} using stds {self._perturbation_dict.values()}""" + f"""Will randomly perturb +{list(self._perturbation_dict.keys())} +using stds {self._perturbation_dict.values()}""" ) perturbed_features = self.rng.normal( loc=node_features[:, self._perturbation_cols], From 822c0d7fde3c6f1f5834e0159178d1b6368e5a89 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 19:57:52 +0200 Subject: [PATCH 22/26] shorten doc string --- src/graphnet/models/graphs/graph_definition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 35eefd5f3..931f5e398 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -200,8 +200,8 @@ def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: if isinstance(self._perturbation_dict, dict): self.warning_once( f"""Will randomly perturb -{list(self._perturbation_dict.keys())} -using stds {self._perturbation_dict.values()}""" + {list(self._perturbation_dict.keys())} + using stds {self._perturbation_dict.values()}""" # noqa ) perturbed_features = self.rng.normal( loc=node_features[:, self._perturbation_cols], From 1758b740d2fe21e501196ab2a08d2e2fed0e9325 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 20:01:00 +0200 Subject: [PATCH 23/26] shorten error strings --- .../models/graphs/graph_definition.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 931f5e398..089eee008 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -113,9 +113,12 @@ def forward( # type: ignore node_feature_names: name of each column. Shape ´[,d]´. truth_dicts: Dictionary containing truth labels. custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels. - loss_weight_column: Name of column that holds loss weight. Defaults to None. + loss_weight_column: Name of column that holds loss weight. 
+            Defaults to None.
         loss_weight: Loss weight associated with event. Defaults to None.
-        loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.
+        loss_weight_default_value: default value for loss weight.
+                                Used in instances where some events have
+                                no pre-defined loss weight. Defaults to None.
         data_path: Path to dataset data files. Defaults to None.
 
     Returns:
@@ -146,7 +149,8 @@ def forward(  # type: ignore
             graph = self._edge_definition(graph)
         else:
             self.warnonce(
-                "No EdgeDefinition provided. Graphs will not have edges defined!"
+                """No EdgeDefinition provided.
+                Graphs will not have edges defined!"""  # noqa
             )
 
         # Attach data path - useful for Ensemble datasets.
@@ -190,11 +194,15 @@ def _validate_input(
         # was instantiated with.
         assert len(node_feature_names) == len(
             self._node_feature_names
-        ), f"""Input features ({node_feature_names}) is not what {self.__class__.__name__} was instatiated with ({self._node_feature_names})"""
+        ), f"""Input features ({node_feature_names}) are not what
+            {self.__class__.__name__} was instantiated
+            with ({self._node_feature_names})"""  # noqa
         for idx in range(len(node_feature_names)):
             assert (
                 node_feature_names[idx] == self._node_feature_names[idx]
-            ), f""" Order of node features in data are not the same as expected. Got {node_feature_names} vs. {self._node_feature_names}"""
+            ), f""" Order of node features in data
+            are not the same as expected. Got {node_feature_names}
+            vs. {self._node_feature_names}"""  # noqa
 
     def _perturb_input(self, node_features: np.ndarray) -> np.ndarray:
         if isinstance(self._perturbation_dict, dict):
             self.warning_once(
                 f"""Will randomly perturb
                 {list(self._perturbation_dict.keys())}
                 using stds {self._perturbation_dict.values()}"""  # noqa
             )
             perturbed_features = self.rng.normal(
                 loc=node_features[:, self._perturbation_cols],
                 scale=np.array(
                     list(self._perturbation_dict.values()), dtype=float
                 ),
             )
             node_features[:, self._perturbation_cols] = perturbed_features
@@ -298,7 +306,8 @@ def _add_features_individually(
             graph[feature] = graph.x[:, index].detach()
         else:
             self.warnonce(
-                """Cannot assign graph['x']. This field is reserved for node features. Please rename your input feature."""
+                """Cannot assign graph['x']. This field is reserved
+                    for node features. Please rename your input feature."""  # noqa
             )
 
         return graph

From afeec3e4dfec9dcd0e6df06e09e97763d0140f52 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Fri, 22 Sep 2023 20:17:53 +0200
Subject: [PATCH 24/26] Replace GenericExtractor in 01-03 for FeatureExtractor

---
 examples/01_icetray/01_convert_i3_files.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py
index 03517f51b..495c9a13c 100644
--- a/examples/01_icetray/01_convert_i3_files.py
+++ b/examples/01_icetray/01_convert_i3_files.py
@@ -5,6 +5,7 @@
 from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR
 from graphnet.data.extractors import (
     I3FeatureExtractorIceCubeUpgrade,
+    I3FeatureExtractorIceCube86,
     I3RetroExtractor,
     I3TruthExtractor,
     I3GenericExtractor,
@@ -34,12 +35,7 @@ def main_icecube86(backend: str) -> None:
 
     converter: DataConverter = CONVERTER_CLASS[backend](
         [
-            I3GenericExtractor(
-                keys=[
-                    "SRTInIcePulses",
-                    "I3MCTree",
-                ]
-            ),
+            I3FeatureExtractorIceCube86("SRTInIcePulses"),
             I3TruthExtractor(),
         ],
         outdir,

From 999e26d885a8590bd4fddf41b860c36747f3c73e Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Fri, 22 Sep 2023 21:07:36 +0200
Subject: [PATCH 25/26] fix typo in readme

---
 examples/04_training/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/04_training/README.md b/examples/04_training/README.md
index 1849515be..934767247 100644
--- a/examples/04_training/README.md
+++ b/examples/04_training/README.md
@@ -34,7 +34,7 @@ This subfolder contains two main training scripts:
   --prediction-names zenith_pred zenith_kappa_pred azimuth_pred azimuth_kappa_pred
 ```
 
-**`02_train_model_dynedge_from_config.py`** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard datasets and models, as it is easier to read and share than doing so in pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running:
+**`03_train_model_dynedge_from_config.py`** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard datasets and models, as it is easier to read and share than doing so in pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running:
 
 ```bash
 # Show the CLI

From 03dc5d2f8a54a21f65f1844967215bd0c02befd8 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Sat, 23 Sep 2023 11:18:36 +0200
Subject: [PATCH 26/26] Update code comment in 04-01

---
 examples/04_training/01_train_dynedge.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/04_training/01_train_dynedge.py b/examples/04_training/01_train_dynedge.py
index 6a8bf2f15..58e513ec2 100644
--- a/examples/04_training/01_train_dynedge.py
+++ b/examples/04_training/01_train_dynedge.py
@@ -164,10 +164,13 @@ def main(
     # Save results as .csv
     results.to_csv(f"{path}/results.csv")
 
-    # Save full model (including weights) to .pth file - Not version proof
+    # Save full model (including weights) to .pth file - not version safe
+    # Note: Models saved as .pth files in one version of graphnet
+    # may not be compatible with a different version of graphnet.
model.save(f"{path}/model.pth") # Save model config and state dict - Version safe save method. + # This method of saving models is the safest way. model.save_state_dict(f"{path}/state_dict.pth") model.save_config(f"{path}/model_config.yml")